# -*- coding: utf-8 -*-
"""Utilities for calculating and plotting linkage disequilibrium."""

from __future__ import division, print_function, absolute_import

# third party dependencies
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mpl_transforms
import scipy.stats as stats
import scipy.spatial.distance as distance

# internal dependencies
import anhima.loc
from anhima.opt.ld import ld_prune_pairwise_uint8 as _ld_prune_pairwise_uint8

def pairwise_genotype_ld(gn):
    """Given a set of genotypes at biallelic variants, calculate the square of
    the correlation coefficient between all distinct pairs of variants.

    Parameters
    ----------
    gn : array_like
        A 2-dimensional array of shape (`n_variants`, `n_samples`) where each
        element is a genotype call coded as a single integer counting the
        number of non-reference alleles.

    Returns
    -------
    r_squared : ndarray, float
        A 2-dimensional array of shape (`n_variants`, `n_variants`) of squared
        correlation coefficients between each pair of variants. This is
        always 2-dimensional, including when only a single variant is given.

    Notes
    -----
    Variants whose genotypes are constant across samples have undefined
    correlation; the corresponding entries are NaN.

    """

    # check input array
    gn = np.asarray(gn)
    assert gn.ndim == 2

    # TODO deal with missing genotypes

    # np.corrcoef collapses a single-row input to a 0-d scalar; promote it
    # back to (1, 1) so callers can always rely on a square 2-D result
    r = np.atleast_2d(np.corrcoef(gn))
    return r ** 2
def plot_pairwise_ld(r_squared, cmap='Greys', flip=True, ax=None):
    """Draw a triangular linkage disequilibrium plot from a square matrix of
    squared correlation coefficients between variants.

    Parameters
    ----------
    r_squared : array_like
        Square 2-dimensional array of squared correlation coefficients
        between pairs of variants.
    cmap : color map, optional
        Color map used to render the values, which are scaled to the range
        [0, 1]. Defaults to 'Greys' (0=white, 1=black).
    flip : bool, optional
        When True, the triangle is drawn pointing downwards.
    ax : axes, optional
        Axes to draw into. When omitted, a new figure is created whose
        height is half its width.

    Returns
    -------
    ax : axes
        The axes holding the plot.

    """

    matrix = np.asarray(r_squared)
    assert matrix.ndim == 2
    assert matrix.shape[0] == matrix.shape[1]

    if ax is None:
        width = plt.rcParams['figure.figsize'][0]
        figure = plt.figure(figsize=(width, width // 2))
        ax = figure.add_axes((0, 0, 1, 1))

    # rotate the mesh 45 degrees clockwise about the origin, then map into
    # data coordinates, so the matrix diagonal lies along the x axis
    rotated = mpl_transforms.Affine2D().rotate_deg_around(0, 0, -45) + ax.transData
    ax.pcolormesh(matrix, cmap=cmap, vmin=0, vmax=1, transform=rotated)

    # discard everything below y=0 so only one triangle remains visible
    ax.set_ylim(bottom=0)
    if flip:
        ax.invert_yaxis()

    ax.set_axis_off()
    return ax
def plot_windowed_ld(gn, pos, window_size, start_position=None,
                     stop_position=None, percentiles=(5, 95), ax=None,
                     median_plot_kwargs=None, percentiles_plot_kwargs=None):
    """Plot average LD within non-overlapping genome windows.

    Parameters
    ----------
    gn : array_like
        A 2-dimensional array of shape (`n_variants`, `n_samples`) where each
        element is a genotype call coded as a single integer counting the
        number of non-reference alleles.
    pos : array_like
        A 1-dimensional array of genomic positions of variants.
    window_size : int
        The size in base-pairs of the windows.
    start_position : int, optional
        The start position for the region over which to work. Defaults to
        the minimum of `pos`.
    stop_position : int, optional
        The stop position for the region over which to work. Defaults to
        the maximum of `pos`.
    percentiles : sequence of integers, optional
        Percentiles to plot in addition to the median.
    ax : axes, optional
        The axes on which to draw. If not provided, a new figure will be
        created.
    median_plot_kwargs : dict, optional
        Keyword arguments to pass through when plotting the median line.
        The dict is copied, not modified.
    percentiles_plot_kwargs : dict, optional
        Keyword arguments to pass through when plotting the percentiles.
        The dict is copied, not modified.

    Returns
    -------
    ax : axes
        The axes on which the plot was drawn.

    Notes
    -----
    Windows are laid out from `start_position` in steps of `window_size`.
    There are as many windows as needed to reach `stop_position`, so the
    last window may extend beyond it. Windows containing fewer than two
    variants have no pairs to summarise; they are left as NaN and show up
    as gaps in the plot.

    """

    # check input array
    gn = np.asarray(gn)
    assert gn.ndim == 2
    pos = np.asarray(pos)

    # set up axes
    if ax is None:
        x = plt.rcParams['figure.figsize'][0]
        fig = plt.figure(figsize=(x, x//3))
        ax = fig.add_subplot(111)

    # determine bins
    if stop_position is None:
        stop_position = np.max(pos)
    if start_position is None:
        start_position = np.min(pos)
    # enough windows to cover [start, stop]; np.arange(start, stop, size)
    # would silently drop the final (possibly full) window
    n_bins = max(int(np.ceil((stop_position - start_position) / window_size)), 1)
    bin_edges = start_position + window_size * np.arange(n_bins + 1)

    # initialise plotting variables; NaN marks windows without any pair so
    # they are not mistaken for windows with zero LD
    med = np.full((n_bins,), np.nan, dtype='f4')
    if percentiles:
        pc = np.full((n_bins, len(percentiles)), np.nan, dtype='f4')

    # iterate over bins
    for n in range(n_bins):

        # determine bin start and stop positions
        bin_start = bin_edges[n]
        bin_stop = bin_edges[n + 1]

        # map genome positions onto variant indices
        loc = anhima.loc.locate_interval(pos, bin_start, bin_stop)

        # at least two variants are needed to form a pair
        if loc.stop - loc.start < 2:
            continue

        # view genotypes for the current region
        gw = gn[loc, :]

        # calculate pairwise LD
        r_squared = pairwise_genotype_ld(gw)

        # convert to non-redundant (condensed) form, one value per pair
        r_squared_nonredundant = distance.squareform(r_squared, checks=False)

        # calculate median and percentiles
        med[n] = np.median(r_squared_nonredundant)
        if percentiles:
            pc[n, :] = np.percentile(r_squared_nonredundant, percentiles)

    # determine x coordinates for plotting, as bin centers
    x = (bin_edges[1:] + bin_edges[:-1]) / 2

    # plot median; copy kwargs so the caller's dict is not mutated
    median_plot_kwargs = dict(median_plot_kwargs or {})
    median_plot_kwargs.setdefault('linestyle', '-')
    median_plot_kwargs.setdefault('color', 'k')
    median_plot_kwargs.setdefault('linewidth', 2)
    ax.plot(x, med, **median_plot_kwargs)

    # plot percentiles
    if percentiles:
        percentiles_plot_kwargs = dict(percentiles_plot_kwargs or {})
        percentiles_plot_kwargs.setdefault('linestyle', '--')
        percentiles_plot_kwargs.setdefault('color', 'k')
        percentiles_plot_kwargs.setdefault('linewidth', 1)
        for i in range(len(percentiles)):
            ax.plot(x, pc[:, i], **percentiles_plot_kwargs)

    # tidy up
    ax.set_xlabel('position')
    ax.set_ylabel('$r^2$', rotation=0)
    ax.grid(axis='y')

    return ax
def ld_prune_pairwise(gn, window_size=100, window_step=10, max_r_squared=.2):
    """Select a subset of biallelic variants in approximate linkage
    equilibrium with each other.

    Parameters
    ----------
    gn : array_like
        A 2-dimensional array of shape (`n_variants`, `n_samples`) where each
        element is a genotype call coded as a single integer counting the
        number of non-reference alleles. Values are converted to unsigned
        8-bit integers before processing.
    window_size : int, optional
        Number of variants considered together in each window.
    window_step : int, optional
        Number of variants the window advances by between iterations.
    max_r_squared : float, optional
        Genotype correlation threshold; variants linked above this value
        are excluded.

    Returns
    -------
    included : ndarray, bool
        One entry per variant: True if the variant is retained, False if it
        was excluded.

    Notes
    -----
    Starting at the beginning of the genotypes array, a window of
    `window_size` variants is taken and the genotype correlation between
    every pair of variants in it is computed. The first variant is kept, and
    every other variant in the window whose linkage with it exceeds
    `max_r_squared` is excluded. The procedure moves on to the next variant
    that has not been excluded, and so on. The window is then advanced by
    `window_step` variants and the process repeats. The work is done by an
    optimised uint8 implementation.

    """

    genotypes = np.asarray(gn).astype('u1')
    assert genotypes.ndim == 2

    return _ld_prune_pairwise_uint8(genotypes, window_size, window_step,
                                    max_r_squared)
def pairwise_ld_decay(r_squared, pos, step=1):
    """Compile data on linkage disequilibrium, separation (in number of
    variants), and physical distance between pairs of variants.

    Parameters
    ----------
    r_squared : array_like
        A square 2-dimensional array of squared correlation coefficients
        between pairs of variants.
    pos : array_like
        A 1-dimensional array of genomic positions of variants, indexed in
        the same order as the rows of `r_squared`.
    step : int, optional
        When compiling the data, only pairs whose first variant index is a
        multiple of `step` are used.

    Returns
    -------
    cor : ndarray, float
        Each element in the array is the squared genotype correlation
        coefficient between a distinct pair of variants.
    sep : ndarray, int
        Each element in the array is the separation (in number of variants)
        between a distinct pair of variants.
    dist : ndarray, int
        Each element in the array is the physical distance between a
        distinct pair of variants.

    Raises
    ------
    ValueError
        If `step` is zero.

    See Also
    --------
    windowed_ld_decay

    """

    # check inputs
    r_squared = np.asarray(r_squared)
    assert r_squared.ndim == 2
    assert r_squared.shape[0] == r_squared.shape[1]
    # signed integers so that pos[j] - pos[i] cannot wrap for unsigned input
    pos = np.asarray(pos, dtype=np.int64)

    # determine the number of variants
    n_variants = r_squared.shape[0]

    # all distinct pairs (i, j) with i < j, in row-major order
    i_idx, j_idx = np.triu_indices(n_variants, k=1)

    # keep only pairs whose first variant is every `step`-th variant; slice
    # assignment raises ValueError for step == 0, like range() would
    row_selected = np.zeros(n_variants, dtype=bool)
    row_selected[0:n_variants:step] = True
    keep = row_selected[i_idx]
    i_idx = i_idx[keep]
    j_idx = j_idx[keep]

    # compile outputs
    cor = r_squared[i_idx, j_idx].astype(float)
    sep = (j_idx - i_idx).astype(int)
    dist = np.abs(pos[j_idx] - pos[i_idx]).astype(int)

    return cor, sep, dist
def windowed_ld_decay(gn, pos, window_size, step=1):
    """Compile data on linkage disequilibrium, separation (in number of
    variants), and physical distance between pairs of variants.

    Parameters
    ----------
    gn : array_like
        A 2-dimensional array of shape (`n_variants`, `n_samples`) where each
        element is a genotype call coded as a single integer counting the
        number of non-reference alleles.
    pos : array_like
        A 1-dimensional array of genomic positions of variants, one per row
        of `gn`.
    window_size : int
        The number of variants to work with at a time.
    step : int, optional
        When compiling the data within each window, advance `step` variants.

    Returns
    -------
    cor : ndarray, float
        Each element in the array is the squared genotype correlation
        coefficient between a distinct pair of variants.
    sep : ndarray, int
        Each element in the array is the separation (in number of variants)
        between a distinct pair of variants.
    dist : ndarray, int
        Each element in the array is the physical distance between a
        distinct pair of variants.

    See Also
    --------
    pairwise_ld_decay

    Notes
    -----
    Similar to :func:`pairwise_ld_decay`, except that not all pairs of
    variants are sampled, which speeds up computation and uses less memory.
    Variants are divided into non-overlapping windows of size `window_size`.
    Genotype LD is calculated for all pairs within each window. A window
    holding a single variant contributes no pairs.

    """

    # check input array
    gn = np.asarray(gn)
    assert gn.ndim == 2
    pos = np.asarray(pos)

    # determine number of variants
    n_variants = gn.shape[0]

    # initialise output variables
    all_cor = list()
    all_sep = list()
    all_dist = list()

    # iterate over non-overlapping windows of variants
    for window_start in range(0, n_variants, window_size):

        # determine extent of current window
        window_stop = min(window_start + window_size, n_variants)

        # a lone variant has no partner to pair with
        if window_stop - window_start < 2:
            continue

        # view genotypes for the current window
        gw = gn[window_start:window_stop, :]

        # calculate LD
        r_squared = pairwise_genotype_ld(gw)

        # compile data; positions must be sliced to the same window so that
        # indices into r_squared line up with indices into the positions
        cor, sep, dist = pairwise_ld_decay(
            r_squared, pos[window_start:window_stop], step=step
        )
        all_cor.append(cor)
        all_sep.append(sep)
        all_dist.append(dist)

    # np.concatenate refuses an empty list, so handle "no pairs" explicitly
    if not all_cor:
        return (np.zeros((0,), dtype=float), np.zeros((0,), dtype=int),
                np.zeros((0,), dtype=int))

    # concatenate results from each window
    all_cor = np.concatenate(all_cor)
    all_sep = np.concatenate(all_sep)
    all_dist = np.concatenate(all_dist)

    return all_cor, all_sep, all_dist
def plot_ld_decay_by_separation(cor, sep, max_separation=100,
                                percentiles=(5, 95), ax=None,
                                median_plot_kwargs=None,
                                percentiles_plot_kwargs=None):
    """Plot the decay of linkage disequilibrium with separation between
    variants.

    Parameters
    ----------
    cor : array_like
        A 1-dimensional array of squared correlation coefficients between
        pairs of variants.
    sep : array_like
        A 1-dimensional array of separations (in number of variants) between
        pairs of variants.
    max_separation : int, optional
        Maximum separation to consider (inclusive). Separations from 1 to
        `max_separation` are summarised.
    percentiles : sequence of integers, optional
        Percentiles to plot in addition to the median.
    ax : axes, optional
        The axes on which to draw. If not provided, a new figure will be
        created.
    median_plot_kwargs : dict, optional
        Keyword arguments to pass through when plotting the median line.
        The dict is copied, not modified.
    percentiles_plot_kwargs : dict, optional
        Keyword arguments to pass through when plotting the percentiles.
        The dict is copied, not modified.

    Returns
    -------
    ax : axes
        The axes on which the plot was drawn.

    Notes
    -----
    Separations with no pairs are left as NaN and appear as gaps.

    """

    # check inputs
    cor = np.asarray(cor)
    sep = np.asarray(sep)
    assert cor.ndim == 1
    assert sep.ndim == 1
    assert cor.shape[0] == sep.shape[0]

    # set up axes
    if ax is None:
        fig, ax = plt.subplots()

    # separations to summarise; a separation of 0 would pair a variant with
    # itself, so start at 1 and include max_separation itself
    x = np.arange(1, max_separation + 1)

    # set up arrays for plotting; NaN marks separations without data
    cor_median = np.full((max_separation,), np.nan, dtype='f4')
    if percentiles:
        cor_percentiles = np.full((max_separation, len(percentiles)), np.nan,
                                  dtype='f4')

    # iterate over separations, compiling data
    for i, s in enumerate(x):
        c = cor[sep == s]
        if len(c) > 0:
            cor_median[i] = np.median(c)
            if percentiles:
                cor_percentiles[i, :] = np.percentile(c, percentiles)

    # plot the median on the requested axes (not the pyplot current axes);
    # copy kwargs so the caller's dict is not mutated
    median_plot_kwargs = dict(median_plot_kwargs or {})
    median_plot_kwargs.setdefault('linestyle', '-')
    median_plot_kwargs.setdefault('color', 'k')
    median_plot_kwargs.setdefault('linewidth', 2)
    median_plot_kwargs.setdefault('label', 'median')
    ax.plot(x, cor_median, **median_plot_kwargs)

    # plot percentiles
    if percentiles:
        percentiles_plot_kwargs = dict(percentiles_plot_kwargs or {})
        percentiles_plot_kwargs.setdefault('linestyle', '--')
        percentiles_plot_kwargs.setdefault('color', 'k')
        percentiles_plot_kwargs.setdefault('linewidth', 1)
        for n, p in enumerate(percentiles):
            # a caller-supplied label takes precedence over the default
            kwargs = dict(label='%s%%' % p)
            kwargs.update(percentiles_plot_kwargs)
            ax.plot(x, cor_percentiles[:, n], **kwargs)

    # tidy up
    ax.set_xlim(left=1, right=max_separation)
    ax.set_ylim(0, 1)
    ax.set_xlabel('separation')
    ax.set_ylabel('$r^2$', rotation=0)
    ax.grid(axis='y')

    return ax
def plot_ld_decay_by_distance(cor, dist, bins, percentiles=(5, 95), ax=None,
                              median_plot_kwargs=None,
                              percentiles_plot_kwargs=None):
    """Plot the decay of linkage disequilibrium with physical distance between
    variants.

    Parameters
    ----------
    cor : array_like
        A 1-dimensional array of squared correlation coefficients between
        pairs of variants.
    dist : array_like
        A 1-dimensional array of physical distances between pairs of
        variants.
    bins : int or sequence of ints
        Number of bins or bin edges. Bins of distance to calculate LD within.
    percentiles : sequence of integers, optional
        Percentiles to plot in addition to the median.
    ax : axes, optional
        The axes on which to draw. If not provided, a new figure will be
        created.
    median_plot_kwargs : dict, optional
        Keyword arguments to pass through when plotting the median line.
        The dict is copied, not modified.
    percentiles_plot_kwargs : dict, optional
        Keyword arguments to pass through when plotting the percentiles.
        The dict is copied, not modified.

    Returns
    -------
    ax : axes
        The axes on which the plot was drawn.

    """

    # check inputs
    cor = np.asarray(cor)
    dist = np.asarray(dist)
    assert cor.ndim == 1
    assert dist.ndim == 1
    assert cor.shape[0] == dist.shape[0]

    # set up axes
    if ax is None:
        fig, ax = plt.subplots()

    # calculate the median of correlation values within bins; empty bins
    # come back as NaN
    y, bin_edges, _ = stats.binned_statistic(dist, values=cor, bins=bins,
                                             statistic='median')

    # determine x axis variable as bin centers
    x = (bin_edges[:-1] + bin_edges[1:]) / 2

    # plot median; copy kwargs so the caller's dict is not mutated
    median_plot_kwargs = dict(median_plot_kwargs or {})
    median_plot_kwargs.setdefault('linestyle', '-')
    median_plot_kwargs.setdefault('color', 'k')
    median_plot_kwargs.setdefault('linewidth', 2)
    median_plot_kwargs.setdefault('label', 'median')
    ax.plot(x, y, **median_plot_kwargs)

    # calculate and plot percentiles
    if percentiles:
        percentiles_plot_kwargs = dict(percentiles_plot_kwargs or {})
        percentiles_plot_kwargs.setdefault('linestyle', '--')
        percentiles_plot_kwargs.setdefault('color', 'k')
        percentiles_plot_kwargs.setdefault('linewidth', 1)
        for p in percentiles:
            # bind p explicitly so the statistic never depends on late binding
            y, _, _ = stats.binned_statistic(
                dist, values=cor, bins=bins,
                statistic=lambda v, p=p: np.percentile(v, p)
            )
            # a caller-supplied label takes precedence over the default
            kwargs = dict(label='%s%%' % p)
            kwargs.update(percentiles_plot_kwargs)
            ax.plot(x, y, **kwargs)

    # tidy up
    ax.set_xlim(np.min(x), np.max(x))
    ax.set_ylim(0, 1)
    ax.set_xlabel('distance')
    ax.set_ylabel('$r^2$', rotation=0)
    ax.grid(axis='y')

    return ax