"""
VCF filtrations and a filterer to apply them.
"""
__author__ = "Janek Sendrowski"
__contact__ = "sendrowski.janek@gmail.com"
__date__ = "2023-05-11"
import functools
import logging
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional, Callable, Dict, Union
import numpy as np
import pandas as pd
from .annotation import DegeneracyAnnotation
from .io_handlers import get_major_base, MultiHandler, get_called_bases, DummyVariant
# package-level logger; each filtration derives a child logger from it in ``Filtration.__init__``
logger = logging.getLogger('fastdfe')
def _count_filtered(func: Callable) -> Callable:
    """
    Decorator for ``filter_site``-like methods: every time the wrapped method
    returns a falsy value, ``self.n_filtered`` is incremented by one.

    :param func: The method to wrap, called as ``func(self, variant)``.
    :return: The wrapping function, which forwards the wrapped method's result unchanged.
    """

    @functools.wraps(func)
    def wrapper(self, variant):
        """
        Call the wrapped method and count the site if it was rejected.

        :param self: The object owning the ``n_filtered`` counter.
        :param variant: The variant to filter.
        :return: Whatever the wrapped method returned.
        """
        outcome = func(self, variant)

        # site passed: nothing to count
        if outcome:
            return outcome

        # site rejected: bump the counter, then hand back the original result
        self.n_filtered += 1

        return outcome

    return wrapper
[docs]
class Filtration(ABC):
    """
    Abstract base class for site filters. A concrete filtration decides, variant by variant,
    whether a site is kept.
    """

    #: The number of sites that were rejected by this filtration.
    n_filtered: int = 0

    def __init__(self):
        """
        Create the filtration with its own child logger and no handler attached yet.
        """
        #: The handler, assigned in :meth:`_setup`.
        self._handler: MultiHandler | None = None

        #: The logger, named after the concrete class.
        self._logger = logger.getChild(type(self).__name__)

    @abstractmethod
    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Decide whether the given site passes the filter. Must be implemented by subclasses.

        :param variant: The variant to filter.
        :return: ``True`` if the variant should be kept, ``False`` otherwise.
        """

    def _setup(self, handler: MultiHandler):
        """
        Pre-processing hook, called before the actual filtration. Stores the handler.

        :param handler: The handler.
        """
        self._handler = handler

    def _rewind(self):
        """
        Rewind the filtration by resetting the counter of rejected sites.
        """
        self.n_filtered = 0

    def _teardown(self):
        """
        Post-processing hook, called after the actual filtration. Logs how many sites were rejected.
        """
        self._logger.info(f"Filtered out {self.n_filtered} sites.")
[docs]
class MaskedFiltration(Filtration, ABC):
    """
    Base class for filtrations that can be restricted to a subset of samples via a boolean mask.
    """

    def __init__(
            self,
            use_parser: bool = True,
            include_samples: List[str] | None = None,
            exclude_samples: List[str] | None = None
    ):
        """
        Create a new filtration instance.

        :param use_parser: Whether to use the samples mask from the parser, if used together with parser.
        :param include_samples: The samples to include, defaults to all samples.
        :param exclude_samples: The samples to exclude, defaults to no samples.
        """
        super().__init__()

        #: The samples mask, built in :meth:`_prepare_samples_mask`.
        self._samples_mask: np.ndarray | None = None

        #: Whether to use the samples mask from the parser, if used together with parser.
        self.use_parser: bool = use_parser

        #: The samples to include.
        self.include_samples: List[str] | None = include_samples

        #: The samples to exclude.
        self.exclude_samples: List[str] | None = exclude_samples

    def _prepare_samples_mask(self):
        """
        Build the samples mask: take it from the parser if requested and available, leave it
        unset if no sample restriction was given, and otherwise derive it from the
        include/exclude lists.
        """
        # imported inside the method rather than at module level
        from .parser import Parser

        # the parser's own mask takes precedence
        if self.use_parser and isinstance(self._handler, Parser):
            self._samples_mask = self._handler._samples_mask
            return

        # nothing to restrict
        if self.include_samples is None and self.exclude_samples is None:
            self._samples_mask = None
            return

        sample_names = self._handler._reader.samples

        # start from the included samples (all of them if none were specified)
        if self.include_samples is None:
            selected = np.ones(len(sample_names), dtype=bool)
        else:
            selected = np.isin(sample_names, self.include_samples)

        # then knock out the excluded samples
        if self.exclude_samples is not None:
            selected &= ~np.isin(sample_names, self.exclude_samples)

        self._samples_mask = selected

    def _setup(self, handler: MultiHandler):
        """
        Store the handler and prepare the samples mask.

        :param handler: The handler.
        """
        super()._setup(handler)

        self._prepare_samples_mask()
[docs]
class SNPFiltration(MaskedFiltration):
    """
    Only keep SNPs. Note that this entails discarding mono-morphic sites.
    """

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Keep the site only if it is an SNP.

        :param variant: The variant to filter.
        :return: ``True`` if the variant is an SNP, ``False`` otherwise.
        """
        mask = self._samples_mask

        # with a samples mask and a real variant, require more than one distinct
        # called base among the included samples
        if mask is not None and not isinstance(variant, DummyVariant):
            called = get_called_bases(variant.gt_bases[mask])

            return len(np.unique(called)) > 1

        # otherwise rely on the variant's own SNP flag
        return variant.is_snp
[docs]
class SNVFiltration(Filtration):
    """
    Only keep single site variants (discard indels and MNPs but keep monomorphic sites).
    """

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Keep the site only if the reference and every alternative allele is a single nucleotide.

        :param variant: The variant to filter.
        :return: ``True`` if the variant is kept, ``False`` otherwise.
        """
        nucleotides = ['A', 'C', 'G', 'T']

        # reference allele followed by all alternative alleles
        alleles = [variant.REF] + variant.ALT

        is_single_base = [allele in nucleotides for allele in alleles]

        return np.all(is_single_base)
[docs]
class PolyAllelicFiltration(MaskedFiltration):
    """
    Filter out poly-allelic sites.
    """

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Keep the site only if it has at most two alleles. Without a samples mask this relies on
        the ``ALT`` field rather than on an explicit check of all alleles.

        :param variant: The variant to filter.
        :return: ``True`` if the variant is not poly-allelic, ``False`` otherwise.
        """
        mask = self._samples_mask

        # with a samples mask and a real variant, count distinct called bases among the included samples
        if mask is not None and not isinstance(variant, DummyVariant):
            called = get_called_bases(variant.gt_bases[mask])

            return len(np.unique(called)) < 3

        # otherwise at most one alternative allele is allowed
        return len(variant.ALT) < 2
[docs]
class AllFiltration(Filtration):
    """
    Reject every site. Only useful for testing purposes.
    """

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Reject the site unconditionally.

        :param variant: The variant to filter (not inspected).
        :return: Always ``False``.
        """
        return False
[docs]
class NoFiltration(Filtration):
    """
    Keep every site. Only useful for testing purposes.
    """

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Keep the site unconditionally.

        :param variant: The variant to filter (not inspected).
        :return: Always ``True``.
        """
        return True
[docs]
class CodingSequenceFiltration(Filtration):
    """
    Filter out sites that are not in coding sequences. This filter should find frequent use when parsing
    spectra for DFE inference as we only consider sites in coding sequences for this purpose.
    By using it, the annotation and parsing of unnecessary sites can be avoided which increases the speed.
    Note that we assume here that within contigs, sites in the GFF file are sorted by position in ascending order.

    For this filtration to work, we require a GFF file (passed to :class:`~fastdfe.parser.Parser` or
    :class:`~fastdfe.filtration.Filterer`).
    """

    def __init__(self):
        """
        Create a new filtration instance.
        """
        super().__init__()

        #: The coding sequence enclosing the current variant or the closest one downstream.
        self.cd: Optional[pd.Series] = None

        #: The number of processed sites.
        self.n_processed: int = 0

    def _setup(self, handler: MultiHandler):
        """
        Touch the GFF file to load it.

        :param handler: The handler.
        """
        # require GFF file
        handler._require_gff(self.__class__.__name__)

        # setup GFF handler
        super()._setup(handler)

        # access the coding sequences once so that they are loaded up front
        _ = handler._cds

    def _rewind(self):
        """
        Rewind the filtration: reset the counter of rejected sites and forget the current coding sequence.
        """
        super()._rewind()

        # reset coding sequence
        self.cd = None

    @_count_filtered
    def filter_site(self, v: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Filter site by whether it is in a coding sequence.

        The current coding sequence is kept between calls and only looked up again when the
        variant is on another contig or lies beyond the end of the current coding sequence.

        :param v: The variant to filter.
        :return: ``True`` if the variant is in a coding sequence, ``False`` otherwise.
        """
        aliases = self._handler.get_aliases(v.CHROM)

        # if self.cd is None or not on the same chromosome or ends before the variant
        if self.cd is None or self.cd.seqid not in aliases or v.POS > self.cd.end:

            # initialize mock coding sequence; it stays in place if no real one is found below
            self.cd = pd.Series({
                'seqid': v.CHROM,
                'start': DegeneracyAnnotation._pos_mock,
                'end': DegeneracyAnnotation._pos_mock
            })

            # find coding sequences on this contig that end at or after the variant
            all_cds = self._handler._cds
            cds = all_cds[all_cds['seqid'].isin(aliases) & (all_cds['end'] >= v.POS)]

            if not cds.empty:
                # take the first coding sequence
                self.cd = cds.iloc[0]

                # the selected coding sequence ends at or after the variant, so it encloses
                # the variant exactly when it also starts at or before it
                if self.cd.start <= v.POS:
                    self._logger.debug(f'Found coding sequence for {v.CHROM}:{v.POS}.')
                else:
                    self._logger.debug(f'Found coding sequence downstream of {v.CHROM}:{v.POS}.')

            # warn once, on the very first site, if nothing but the mock could be found
            if self.n_processed == 0 and self.cd.start == DegeneracyAnnotation._pos_mock:
                self._logger.warning(f'No subsequent coding sequence found on the same contig as the first variant. '
                                     f'Please make sure this is the correct GFF file with contig names matching '
                                     f'the VCF file. You can use the aliases parameter to match contig names.')

        self.n_processed += 1

        # check whether the variant is in the current coding sequence
        return bool(self.cd.seqid in aliases and self.cd.start <= v.POS <= self.cd.end)
[docs]
class DeviantOutgroupFiltration(Filtration):
    """
    Filter out sites where the major allele of the specified outgroup samples differs from the major
    allele of the ingroup samples.
    """

    def __init__(
            self,
            outgroups: List[str],
            ingroups: List[str] | None = None,
            strict_mode: bool = True,
            retain_monomorphic: bool = True
    ):
        """
        Construct DeviantOutgroupFiltration.

        :param outgroups: The name of the outgroup samples to consider.
        :param ingroups: The name of the ingroup samples to consider, defaults to all samples but the outgroups.
        :param strict_mode: Whether to filter out sites where no outgroup sample is present, defaults to ``True``.
        :param retain_monomorphic: Whether to retain monomorphic sites, defaults to ``True``, which is faster.
        """
        super().__init__()

        #: The ingroup samples.
        self.ingroups: List[str] | None = ingroups

        #: The outgroup samples.
        self.outgroups: List[str] = outgroups

        #: Whether to filter out sites where no outgroup sample is present.
        self.strict_mode: bool = strict_mode

        #: Whether to retain monomorphic sites.
        self.retain_monomorphic: bool = retain_monomorphic

        #: The samples found in the VCF file.
        self.samples: Optional[np.ndarray] = None

        #: The ingroup mask.
        self.ingroup_mask: Optional[np.ndarray] = None

        #: The outgroup mask.
        self.outgroup_mask: Optional[np.ndarray] = None

    def _setup(self, handler: MultiHandler):
        """
        Touch the reader to load the samples and build the ingroup and outgroup masks.

        :param handler: The handler.
        """
        super()._setup(handler)

        # create samples array
        self.samples: np.ndarray = np.array(handler._reader.samples)

        # create ingroup and outgroup masks
        self._create_masks()

    def _create_masks(self):
        """
        Create ingroup and outgroup masks based on the samples.

        :raises ValueError: If any of the requested outgroup samples is not among the samples.
        """
        # create outgroup mask
        self.outgroup_mask: np.ndarray = np.isin(self.samples, self.outgroups)

        # make sure all outgroups are present; checking each requested name individually
        # (rather than comparing counts) avoids a spurious error when a name is listed twice
        available = set(np.asarray(self.samples).tolist())

        if any(outgroup not in available for outgroup in self.outgroups):
            raise ValueError(f'Not all outgroup samples are present in the VCF file: {self.outgroups}')

        # create ingroup mask
        if self.ingroups is None:
            self.ingroup_mask = ~self.outgroup_mask
        else:
            self.ingroup_mask = np.isin(self.samples, self.ingroups)

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Filter site by comparing the major base of the ingroup samples with that of the outgroup samples.

        :param variant: The variant to filter.
        :return: ``True`` if the variant should be kept, ``False`` otherwise.
        """
        # keep monomorphic sites if requested
        if not variant.is_snp and self.retain_monomorphic:
            return True

        # filter out dummies if retain_monomorphic is false
        if isinstance(variant, DummyVariant):
            return False

        # get major base among ingroup samples
        ingroup_base = get_major_base(variant.gt_bases[self.ingroup_mask])

        # get major base among outgroup samples
        outgroup_base = get_major_base(variant.gt_bases[self.outgroup_mask])

        # without an outgroup base, the site is only kept when strict mode is disabled
        if outgroup_base is None:
            return not self.strict_mode

        # keep the site only if outgroup and ingroup agree on the major base
        return ingroup_base == outgroup_base
[docs]
class ExistingOutgroupFiltration(Filtration):
    """
    Filter out sites for which at least ``n_missing`` of the specified outgroup samples have no called base.
    """

    def __init__(self, outgroups: List[str], n_missing: int = 1):
        """
        Construct ExistingOutgroupFiltration.

        :param outgroups: The names of the outgroup samples considered.
        :param n_missing: The number of outgroup samples that need to be missing to fail the filter.
        """
        super().__init__()

        #: The samples found in the VCF file.
        self.samples: Optional[np.ndarray] = None

        #: The outgroup mask.
        self.outgroup_mask: Optional[np.ndarray] = None

        #: The outgroup samples.
        self.outgroups: List[str] = outgroups

        #: Minimum number of missing outgroups required to filter out a site.
        self.n_missing: int = n_missing

    def _setup(self, handler: MultiHandler):
        """
        Touch the reader to load the samples and build the outgroup mask.

        :param handler: The handler.
        """
        super()._setup(handler)

        # keep the sample names as an array
        self.samples: np.ndarray = np.array(handler._reader.samples)

        # derive the outgroup mask from them
        self._create_mask()

    def _create_mask(self):
        """
        Create the outgroup mask, marking the samples that are listed as outgroups.
        """
        self.outgroup_mask: np.ndarray = np.isin(self.samples, self.outgroups)

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Keep the site unless too many outgroup samples lack a called base.

        :param variant: The variant to filter.
        :return: ``True`` if the variant should be kept, ``False`` otherwise.
        """
        # dummy variants always pass
        if isinstance(variant, DummyVariant):
            return True

        # tally the outgroup genotypes without any called base
        n_absent = 0
        for genotype in variant.gt_bases[self.outgroup_mask]:
            if len(get_called_bases(genotype)) == 0:
                n_absent += 1

        # the site passes as long as fewer than n_missing outgroups are missing
        return n_absent < self.n_missing
[docs]
class BiasedGCConversionFiltration(Filtration):
    """
    Only retain A<->T and G<->C substitutions (which are unaffected
    by biased gene conversion, see [CITGB]_).
    Mono-allelic sites are always retained, and we assume sites are at most bi-allelic. Note that the number of
    mutational target sites is reduced by this filtration.

    .. [CITGB] Pouyet et al., 'Background selection and biased
        gene conversion affect more than 95% of the human genome and bias demographic inferences.',
        Elife, 7:e36317, 2018
    """

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Remove bi-allelic sites that are not A<->T or G<->C mutations.

        :param variant: The variant to filter.
        :return: ``True`` if the variant should be kept, ``False`` otherwise.
        """
        # anything that is not an SNP with an alternative allele is retained as is
        if not (variant.is_snp and len(variant.ALT) > 0):
            return True

        # only the first alternative allele is considered
        substitution = (variant.REF, variant.ALT[0])

        return substitution in [('A', 'T'), ('T', 'A'), ('G', 'C'), ('C', 'G')]
[docs]
class ContigFiltration(Filtration):
    """
    Filter out sites that are not on the specified contigs.
    """

    def __init__(self, contigs: List[str]):
        """
        Construct ContigFiltration.

        :param contigs: The contigs to retain.
        """
        super().__init__()

        #: The contigs to retain.
        self.contigs: List[str] = contigs

    @_count_filtered
    def filter_site(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Keep the site only if its contig is among the retained contigs.

        :param variant: The variant to filter.
        :return: ``True`` if the variant is on one of the specified contigs, ``False`` otherwise.
        """
        contig = variant.CHROM

        return contig in self.contigs
[docs]
class Filterer(MultiHandler):
    """
    Filter a VCF file using a list of filtrations.

    Example usage:

    ::

        import fastdfe as fd

        # only keep variants in coding sequences
        f = fd.Filterer(
            vcf="http://ftp.1000genomes.ebi.ac.uk/vol1/ftp/data_collections/"
                "1000_genomes_project/release/20181203_biallelic_SNV/"
                "ALL.chr21.shapeit2_integrated_v1a.GRCh38.20181129.phased.vcf.gz",
            gff="http://ftp.ensembl.org/pub/release-109/gff3/homo_sapiens/"
                "Homo_sapiens.GRCh38.109.chromosome.21.gff3.gz",
            output='sapiens.chr21.coding.vcf.gz',
            filtrations=[fd.CodingSequenceFiltration()],
            aliases=dict(chr21=['21'])
        )

        f.filter()

    """

    def __init__(
            self,
            vcf: str | Iterable['cyvcf2.Variant'],
            output: str,
            gff: str | None = None,
            filtrations: List[Filtration] | None = None,
            info_ancestral: str = 'AA',
            max_sites: int = np.inf,
            seed: int | None = 0,
            cache: bool = True,
            aliases: Dict[str, List[str]] | None = None
    ):
        """
        Create a new filter instance.

        :param vcf: The VCF file, possibly gzipped or a URL.
        :param output: The output file.
        :param gff: The GFF file, possibly gzipped or a URL. This argument is required for some filtrations.
        :param filtrations: The filtrations, defaults to no filtrations.
        :param info_ancestral: The info field for the ancestral allele.
        :param max_sites: The maximum number of sites to process.
        :param seed: The seed for the random number generator. Use ``None`` for no seed.
        :param cache: Whether to cache files downloaded from urls.
        :param aliases: Dictionary of aliases for the contigs in the VCF file, e.g. ``{'chr1': ['1']}``,
            defaults to no aliases.
        """
        # create fresh containers per instance instead of sharing mutable default arguments
        super().__init__(
            vcf=vcf,
            gff=gff,
            info_ancestral=info_ancestral,
            max_sites=max_sites,
            seed=seed,
            cache=cache,
            aliases={} if aliases is None else aliases
        )

        #: The filtrations.
        self.filtrations: List[Filtration] = [] if filtrations is None else filtrations

        #: The output file.
        self.output: str = output

        #: The number of sites that did not pass the filters.
        self.n_filtered: int = 0

        #: The VCF writer.
        self._writer: 'cyvcf2.Writer' | None = None

    def is_filtered(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> bool:
        """
        Whether the given variant is kept. Note that, despite the name, ``True`` means
        that the variant passed all filtrations.

        :param variant: The variant to check.
        :return: ``True`` if the variant is kept, ``False`` otherwise.
        """
        # apply the filtrations in order and stop at the first one that rejects the variant
        for filtration in self.filtrations:
            if not filtration.filter_site(variant):
                self.n_filtered += 1
                return False

        return True

    def _setup(self):
        """
        Set up the filtrations and create the VCF writer.

        :raises ImportError: If the optional ``cyvcf2`` package is not installed.
        """
        try:
            from cyvcf2 import Writer
        except ImportError as e:
            raise ImportError(
                "VCF support in fastdfe requires the optional 'cyvcf2' package. "
                "Please install fastdfe with the 'vcf' extra: pip install fastdfe[vcf]"
            ) from e

        # setup filtrations
        for f in self.filtrations:
            f._setup(self)

        # create the writer
        self._writer = Writer(self.output, self._reader)

    def _teardown(self):
        """
        Tear down the filtrations and close the writer and reader.
        """
        for f in self.filtrations:
            f._teardown()

        # close the writer and reader
        self._writer.close()
        self._reader.close()

    def filter(self):
        """
        Filter the VCF: write every variant that passes all filtrations to the output file.
        """
        self._logger.info('Start filtering')

        # setup filtrations
        self._setup()

        try:
            # get progress bar
            with self.get_pbar(desc=f"{self.__class__.__name__}>Processing sites") as pbar:

                # iterate over the sites
                for i, variant in enumerate(self._reader):

                    if self.is_filtered(variant):
                        # write the variant
                        self._writer.write_record(variant)

                    pbar.update()

                    # explicitly stopping after ``n`` sites fixes a bug with cyvcf2:
                    # 'error parsing variant with `htslib::bcf_read` error-code: 0 and ret: -2'
                    if i + 1 == self.n_sites or i + 1 == self.max_sites:
                        break
        finally:
            # teardown filtrations; done in ``finally`` so that the writer and reader
            # are closed even if a filtration raises while iterating
            self._teardown()

        self._logger.info(f'Filtered out {self.n_filtered} of {self.n_sites} sites in total.')