from typing import Any, Optional

from glimix_core.lmm import LMM
from numpy import (
    asarray,
    atleast_1d,
    atleast_2d,
    concatenate,
    inf,
    linspace,
    ones,
    sqrt,
    stack,
)
from numpy.linalg import cholesky
from numpy_sugar import ddot
from numpy_sugar.linalg import economic_qs_linear, economic_svd
from tqdm import tqdm

from ._math import PMat, QSCov, ScoreStatistic
class CellRegMap:
    """
    Mixed-model with genetic effect heterogeneity.

    The CellRegMap model can be cast as:

        y = W alpha + g beta1 + g * beta2 + e + u + eps,                  (1)

    where ``*`` denotes the element-wise product and:

        beta2 ~ N(0, v3 E0 E0^T),
        e     ~ N(0, v1 rho1 E1 E1^T),
        u     ~ N(0, v1 (1 - rho1) K * E2 E2^T), and
        eps   ~ N(0, v2 I).

    ``g * beta2`` is a random effect term which models the GxC effect.
    Additionally, ``W alpha`` models additive covariates and ``g beta1``
    models persistent genetic effects. Both are modelled as fixed effects.
    On the other hand, ``e``, ``u`` and ``eps`` are modelled as random
    effects: ``e`` is the environment effect, ``u`` is a background term
    accounting for interactions between population structure and
    environmental structure, and ``eps`` is the iid noise.

    The full covariance of ``y`` is therefore given by:

        cov(y) = v3 D E0 E0^T D + v1 rho1 E1 E1^T
                 + v1 (1 - rho1) K * E2 E2^T + v2 I,

    where ``D = diag(g)``. Its marginalised form is given by:

        y ~ N(W alpha + g beta1,
              v3 D E0 E0^T D
              + v1 (rho1 E1 E1^T + (1 - rho1) K * E2 E2^T) + v2 I).

    The CellRegMap method is used to perform an interaction test, which
    compares the following hypotheses (from Eq. 1):

        H0: v3 = 0
        H1: v3 > 0

    H0 denotes no GxE effects, while H1 models the presence of GxE effects.
    """

    def __init__(self, y, E, W=None, Ls=None, E1=None, hK=None):
        """
        Store the data and pre-compute the background covariance.

        Parameters
        ----------
        y : array_like
            Phenotype; flattened to one dimension.
        E : array_like
            Contexts used for the GxC term (stored as ``E0``), 2-dimensional.
        W : array_like, optional
            Fixed-effect covariates. Defaults to a single column of ones.
        Ls : list of array_like, optional
            Half-matrices of the ``K * E E^T`` background. When given and
            non-empty they take precedence over ``hK``.
        E1 : array_like, optional
            Contexts used for the environment term. Defaults to ``E``.
        hK : array_like, optional
            Half-matrix of the kinship; only used when ``Ls`` is empty, in
            which case the background is ``K + E1 E1^T``.
        """
        self._y = asarray(y, float).flatten()
        n_samples = self._y.shape[0]
        self._E0 = asarray(E, float)
        self._W = ones((n_samples, 1)) if W is None else asarray(W, float)
        self._E1 = asarray(E if E1 is None else E1, float)
        self._Ls = [asarray(L, float) for L in ([] if Ls is None else Ls)]

        # Assertions are kept (rather than raising another exception type) so
        # that callers relying on AssertionError are unaffected.
        assert self._W.ndim == 2, "W must be 2-dimensional"
        assert self._E0.ndim == 2, "E must be 2-dimensional"
        assert self._E1.ndim == 2, "E1 must be 2-dimensional"
        assert n_samples == self._W.shape[0], "W and y disagree on sample count"
        assert n_samples == self._E0.shape[0], "E and y disagree on sample count"
        assert n_samples == self._E1.shape[0], "E1 and y disagree on sample count"
        # Validate the converted arrays so that plain nested lists work too.
        for L in self._Ls:
            assert L.ndim == 2, "every L must be 2-dimensional"
            assert n_samples == L.shape[0], "L and y disagree on sample count"

        self._null_lmm_assoc = {}
        self._halfSigma = {}
        self._Sigma_qs = {}
        # TODO: remove it after debugging
        self._Sigma = {}

        # Background selection: explicit Ls win, then hK, else E1 E1^T alone.
        if self._Ls:
            background = self._Ls
        elif hK is not None:
            background = [asarray(hK, float)]
        else:
            background = None

        if background is None:
            self._rho1 = [1.0]
            self._halfSigma[1.0] = self._E1
            self._Sigma_qs[1.0] = economic_qs_linear(self._E1, return_q1=False)
        else:
            # Sigma[rho1] = rho1 * E1 E1^T + (1 - rho1) * background
            self._rho1 = linspace(0, 1, 11)
            for rho1 in self._rho1:
                hS = self._half_covariance(rho1, self._E1, background)
                self._halfSigma[rho1] = hS
                self._Sigma_qs[rho1] = economic_qs_linear(hS, return_q1=False)

    @property
    def n_samples(self):
        """Number of samples, i.e. the length of the phenotype vector."""
        return self._y.shape[0]

    @staticmethod
    def _half_covariance(rho1, first, rest):
        """Stack ``sqrt(rho1) * first`` and ``sqrt(1 - rho1) * L`` column-wise."""
        a = sqrt(rho1)
        b = sqrt(1 - rho1)
        return concatenate([a * first] + [b * L for L in rest], axis=1)

    @staticmethod
    def _record_variances(info, rho1, lmm):
        """Append rho1 and the variance components of ``lmm`` to ``info``."""
        info["rho1"].append(rho1)
        info["e2"].append(lmm.v0 * rho1)
        info["g2"].append(lmm.v0 * (1 - rho1))
        info["eps2"].append(lmm.v1)

    def _fit_best_lmm(self, X, qs_by_rho1, restricted):
        """
        Fit one LMM per value in ``self._rho1`` and keep the best one.

        Parameters
        ----------
        X : ndarray
            Fixed-effect design matrix.
        qs_by_rho1 : dict
            Maps each rho1 to the QS decomposition passed to ``LMM``.
        restricted : bool
            Forwarded to ``LMM`` (REML when true, ML otherwise).

        Returns
        -------
        tuple
            ``(rho1, lmm)`` with the largest log marginal likelihood.

        Raises
        ------
        RuntimeError
            If no fit produced a log marginal likelihood above ``-inf``
            (for instance when every value is NaN).
        """
        best_lml = -inf
        best_rho1 = 0
        best_lmm = None
        for rho1 in self._rho1:
            lmm = LMM(self._y, X, qs_by_rho1[rho1], restricted=restricted)
            lmm.fit(verbose=False)
            lml = lmm.lml()
            if lml > best_lml:
                best_lml = lml
                best_rho1 = rho1
                best_lmm = lmm
        if best_lmm is None:
            raise RuntimeError(
                "No LMM fit produced a valid log marginal likelihood."
            )
        return best_rho1, best_lmm

    def predict_interaction(self, G, MAF):
        """
        Estimate effect sizes for a given set of SNPs.

        Parameters
        ----------
        G : array_like
            Genotypes, one column per SNP.
        MAF : array_like
            Minor allele frequencies, indexed like the columns of ``G``.

        Returns
        -------
        tuple
            ``(beta_g, beta_gxe)``: the per-SNP persistent effects as an
            array, and the stacked (then transposed) GxC effects.
        """
        G = asarray(G, float)
        E0 = self._E0
        W = self._W
        n_snps = G.shape[1]
        beta_g_s = []
        beta_gxe_s = []
        p = asarray(atleast_1d(MAF), float)
        normalization = 1 / sqrt(2 * p * (1 - p))
        for i in range(n_snps):
            g = G[:, [i]]
            # mean(y) = W alpha + g beta1 + E0 gamma = M delta
            M = concatenate((W, g, E0), axis=1)
            gE = g * E0
            # Sigma[rho1] = rho1 (g*E0)(g*E0)^T + (1 - rho1) K*EE^T
            # cov(y) = v1 Sigma[rho1] + v2 I
            Sigma_qs = {}
            for rho1 in self._rho1:
                hS = self._half_covariance(rho1, gE, self._Ls)
                Sigma_qs[rho1] = economic_qs_linear(hS, return_q1=False)
            rho1, lmm = self._fit_best_lmm(M, Sigma_qs, restricted=True)
            # beta_g is the coefficient of g, which follows the W columns in M
            beta_g = lmm.beta[W.shape[1]]
            # yadj = y - M delta
            yadj = (self._y - lmm.mean()).reshape(self._y.shape[0], 1)
            v1 = lmm.v0
            v2 = lmm.v1
            # Reuse the decomposition computed for the winning rho1.
            best_qs = Sigma_qs[rho1]
            qscov = QSCov(best_qs[0][0], best_qs[1], v1, v2)
            # v = cov(y)^-1 (y - M delta)
            v = qscov.solve(yadj)
            sigma2_gxe = v1 * rho1
            beta_gxe = sigma2_gxe * E0 @ (gE.T @ v) * normalization[i]
            beta_g_s.append(beta_g)
            beta_gxe_s.append(beta_gxe)
        return (asarray(beta_g_s), stack(beta_gxe_s).T)

    def estimate_aggregate_environment(self, g):
        """
        Project the estimated GxC effect of one SNP onto the contexts.

        Parameters
        ----------
        g : array_like
            Genotype of a single SNP; reshaped to a column vector.

        Returns
        -------
        ndarray
            ``E0 @ beta_gxe`` for the best-fitting rho1.
        """
        g = asarray(g).reshape(-1, 1)
        E0 = self._E0
        gE = g * E0
        M = concatenate((self._W, g, E0), axis=1)
        # NOTE(review): the LMM is fitted with the pre-computed background
        # decomposition (self._Sigma_qs) while the solve below uses the
        # GxC-based covariance, unlike predict_interaction which uses the
        # latter for both. Behaviour is kept as-is -- confirm it is intended.
        rho1, lmm = self._fit_best_lmm(M, self._Sigma_qs, restricted=True)
        yadj = self._y - lmm.mean()
        v1 = lmm.v0
        v2 = lmm.v1
        sigma2_gxe = rho1 * v1
        # Sigma_p = rho1 (g*E0)(g*E0)^T + (1 - rho1) K*EE^T; only the
        # half-matrix for the selected rho1 is needed.
        hSigma_p = self._half_covariance(rho1, gE, self._Ls)
        hSigma_p_qs = economic_qs_linear(hSigma_p, return_q1=False)
        qscov = QSCov(hSigma_p_qs[0][0], hSigma_p_qs[1], v1, v2)
        # v = cov(y)^-1 yadj
        v = qscov.solve(yadj)
        beta_gxe = sigma2_gxe * gE.T @ v
        return E0 @ beta_gxe

    def scan_association(self, G):
        """
        Likelihood-ratio association test, refitting an LMM per SNP.

        Parameters
        ----------
        G : array_like
            Genotypes, one column per SNP.

        Returns
        -------
        tuple
            ``(pvalues, info)`` where ``info`` holds the null-model ``rho1``,
            ``e2``, ``g2`` and ``eps2`` as float arrays.
        """
        G = asarray(G, float)
        info = {"rho1": [], "e2": [], "g2": [], "eps2": []}
        # NULL model.
        # LRT for fixed effects requires ML rather than REML estimation.
        rho1, null_lmm = self._fit_best_lmm(
            self._W, self._Sigma_qs, restricted=False
        )
        self._record_variances(info, rho1, null_lmm)

        # Alternative models share the covariance selected under the null.
        QS = self._Sigma_qs[rho1]
        alt_lmls = []
        for i in tqdm(range(G.shape[1])):
            X = concatenate((self._W, G[:, [i]]), axis=1)
            alt_lmm = LMM(self._y, X, QS, restricted=False)
            alt_lmm.fit(verbose=False)
            alt_lmls.append(alt_lmm.lml())
        pvalues = lrt_pvalues(null_lmm.lml(), alt_lmls, dof=1)
        info = {key: asarray(v, float) for key, v in info.items()}
        return asarray(pvalues, float), info

    def scan_association_fast(self, G):
        """
        Likelihood-ratio association test using the LMM fast scanner.

        Parameters
        ----------
        G : array_like
            Genotypes, one column per SNP.

        Returns
        -------
        tuple
            ``(pvalues, info)`` where ``info`` holds the null-model ``rho1``,
            ``e2``, ``g2`` and ``eps2`` as float arrays.
        """
        G = asarray(G, float)
        info = {"rho1": [], "e2": [], "g2": [], "eps2": []}
        # NULL model.
        # LRT for fixed effects requires ML rather than REML estimation.
        rho1, null_lmm = self._fit_best_lmm(
            self._W, self._Sigma_qs, restricted=False
        )
        self._record_variances(info, rho1, null_lmm)

        # Alternative model: scan every SNP starting from the null fit.
        flmm = null_lmm.get_fast_scanner()
        alt_lmls = flmm.fast_scan(G, verbose=False)["lml"]
        pvalues = lrt_pvalues(null_lmm.lml(), alt_lmls, dof=1)
        info = {key: asarray(v, float) for key, v in info.items()}
        return asarray(pvalues, float), info

    def scan_interaction(
        self, G, idx_E: Optional[Any] = None, idx_G: Optional[Any] = None
    ):
        """
        Score test for GxC interaction, one SNP at a time.

            y = W alpha + g beta1 + g * beta2 + e + u + eps
                [    fixed = X    ]   [ H1 ]

            beta2 ~ N(0, v3 E0 E0^T),
            e     ~ N(0, v1 rho1 E1 E1^T),
            u     ~ N(0, v1 (1 - rho1) K * E2 E2^T), and
            eps   ~ N(0, v2 I).

            H0: v3 = 0
            H1: v3 > 0

        Parameters
        ----------
        G : array_like
            Genotypes, one column per SNP.
        idx_E : optional
            Row index applied to ``E0`` in the test statistic (useful for
            permutations).
        idx_G : optional
            Index applied to each genotype vector in the test statistic
            (useful for permutations).

        Returns
        -------
        tuple
            ``(pvalues, info)`` where ``info`` holds the per-SNP null-model
            ``rho1``, ``e2``, ``g2`` and ``eps2`` as float arrays.
        """
        # TODO: make sure G is nxp
        from chiscore import davies_pvalue

        G = asarray(G, float)
        n_snps = G.shape[1]
        # Useful for permutation; does not depend on the SNP.
        E0 = self._E0 if idx_E is None else self._E0[idx_E, :]
        pvalues = []
        info = {"rho1": [], "e2": [], "g2": [], "eps2": []}
        for i in tqdm(range(n_snps)):
            g = G[:, [i]]
            X = concatenate((self._W, g), axis=1)
            # Null model fitting: find the best (alpha, beta1, v1, v2, rho1)
            # with Sigma = rho1 E1 E1^T + (1 - rho1) K*EE^T and
            # cov(y) = v1 Sigma + v2 I.
            rho1, lmm = self._fit_best_lmm(X, self._Sigma_qs, restricted=True)

            # H1 via score test.
            # Let K0 = e2 E1 E1^T + g2 K*EE^T + eps2 I with
            #   e2 = v1 rho1, g2 = v1 (1 - rho1), eps2 = v2
            # at the optimal v1 and v2 found above.
            self._record_variances(info, rho1, lmm)

            # QS = economic decomposition of Sigma(rho1)
            best_qs = self._Sigma_qs[rho1]
            qscov = QSCov(best_qs[0][0], best_qs[1], lmm.v0, lmm.v1)

            # Let P0 = K0^-1 - K0^-1 X (X^T K0^-1 X)^-1 X^T K0^-1.
            P = PMat(qscov, X)

            # The covariance matrix of H1 is
            #   K = K0 + v3 diag(g) E0 E0^T diag(g),
            # so dK/dv3 = diag(g) E0 E0^T diag(g) and the score statistic is
            #   Q = 1/2 y^T P0 (dK) P0 y.
            # Useful for permutation
            gtest = g.ravel() if idx_G is None else g.ravel()[idx_G]
            ss = ScoreStatistic(P, qscov, ddot(gtest, E0))
            Q = ss.statistic(self._y)

            # Q follows a linear combination of chi-squared (df=1)
            # distributions, Q ~ sum_i lambda_i chi2, where lambda_i are the
            # non-zero eigenvalues of 1/2 sqrt(P0) (dK) sqrt(P0). Since
            # eigenvals(A A^T) = eigenvals(A^T A) (TODO: find citation), one
            # can use 1/2 sqrt(dK) P0 sqrt(dK) instead.
            # TODO: compare with Liu approximation, maybe try a
            # computationally intensive method
            pval, _ = davies_pvalue(Q, ss.matrix_for_dist_weights(), True)
            pvalues.append(pval)

        info = {key: asarray(v, float) for key, v in info.items()}
        return asarray(pvalues, float), info
def lrt_pvalues(null_lml, alt_lmls, dof=1):
    """
    Turn log marginal likelihoods into likelihood-ratio test p-values.

    Parameters
    ----------
    null_lml : float
        Log of the marginal likelihood under the null hypothesis.
    alt_lmls : array_like
        Log of the marginal likelihoods under the alternative hypotheses.
    dof : int
        Degrees of freedom of the chi-squared reference distribution.

    Returns
    -------
    pvalues : ndarray
        P-values, bounded away from both 0 and 1.
    """
    from numpy import clip
    from numpy_sugar import epsilon
    from scipy.stats import chi2

    alternatives = asarray(alt_lmls, float)
    # Likelihood-ratio statistic, floored at a tiny positive value.
    statistic = clip(2 * alternatives - 2 * null_lml, epsilon.super_tiny, inf)
    survival = chi2.sf(statistic, df=dof)
    return clip(survival, epsilon.super_tiny, 1 - epsilon.tiny)
def run_association(y, W, E, G, hK=None):
    """
    Association test.

    Test for persistent genetic effects.
    Compute p-values using a likelihood ratio test.

    Parameters
    ----------
    y : array
        Phenotype
    W : array
        Fixed effect covariates
    E : array
        Cellular contexts
    G : array
        Genotypes (expanded)
    hK : array
        decomposition of kinship matrix (expanded)

    Returns
    -------
    tuple
        The ``(pvalues, info)`` pair returned by
        ``CellRegMap.scan_association``.
    """
    # Keyword arguments are required here: the constructor's positional order
    # is (y, E, W, ...), so passing (y, W, E) positionally swaps W and E.
    crm = CellRegMap(y=y, E=E, W=W, hK=hK)
    return crm.scan_association(G)
def run_association_fast(y, W, E, G, hK=None):
    """
    Association test (fast scanner variant).

    Test for persistent genetic effects.
    Compute p-values using a likelihood ratio test.

    Parameters
    ----------
    y : array
        Phenotype
    W : array
        Fixed effect covariates
    E : array
        Cellular contexts
    G : array
        Genotypes (expanded)
    hK : array
        decomposition of kinship matrix (expanded)

    Returns
    -------
    tuple
        The ``(pvalues, info)`` pair returned by
        ``CellRegMap.scan_association_fast``.
    """
    # Keyword arguments are required here: the constructor's positional order
    # is (y, E, W, ...), so passing (y, W, E) positionally swaps W and E.
    crm = CellRegMap(y=y, E=E, W=W, hK=hK)
    return crm.scan_association_fast(G)
def get_L_values(hK, E):
    """
    Build the list of ``L`` half-matrices from a kinship decomposition.

    As the definition of Ls is not particularly intuitive, this helper
    derives them from the kinship half-matrix ``hK`` and the cellular
    environments ``E``.
    """
    # Economic SVD of E: the scaled left singular vectors decompose E E^T.
    left_vectors, singular_values, _ = economic_svd(E)
    scaled = left_vectors * singular_values

    # One half-matrix of K * E E^T per scaled singular vector.
    Ls = []
    for column in range(scaled.shape[1]):
        Ls.append(ddot(scaled[:, column], hK))
    return Ls
def run_interaction(y, E, G, W=None, E1=None, E2=None, hK=None, idx_G=None):
    """
    Interaction test.

    Test for cell-level genetic effects due to GxC interactions.
    Compute p-values using a score test.

    Parameters
    ----------
    y : array
        Phenotype
    E : array
        Cellular contexts (GxC component)
    G : array
        Genotypes (expanded)
    W : array
        Fixed effect covariates
    hK : array
        decomposition of kinship matrix (expanded)
    E1 : array
        Cellular contexts (C component); defaults to ``E``
    E2 : array
        Cellular contexts (K*C component); defaults to ``E``
    idx_G : array
        Permuted genotype index

    Returns
    -------
    tuple
        The ``(pvalues, info)`` pair returned by
        ``CellRegMap.scan_interaction``.
    """
    if E1 is None:
        E1 = E
    if E2 is None:
        E2 = E
    Ls = None if hK is None else get_L_values(hK, E2)
    crm = CellRegMap(y=y, E=E, W=W, E1=E1, Ls=Ls)
    # idx_G must be passed by keyword: the second positional parameter of
    # scan_interaction is idx_E, which permutes the contexts instead.
    return crm.scan_interaction(G, idx_G=idx_G)
def compute_maf(X):
    r"""Compute minor allele frequencies.

    It assumes that ``X`` encodes 0, 1, and 2 representing the number
    of alleles (or dosage), or ``NaN`` to represent missing values.

    dask, xarray and pandas are imported lazily and are optional: when one
    of them is not installed, its container type is simply not recognised
    and the input goes through the generic NumPy path.

    Parameters
    ----------
    X : array_like
        Genotype matrix, samples along the first axis.

    Returns
    -------
    array_like
        Minor allele frequencies. When the result has a ``name`` attribute
        it is set to ``"maf"``.

    Examples
    --------
    .. doctest::

        >>> from numpy.random import RandomState
        >>>
        >>> random = RandomState(0)
        >>> X = random.randint(0, 3, size=(100, 10))
        >>>
        >>> print(compute_maf(X)) # doctest: +FLOAT_CMP
        [0.49  0.49  0.445 0.495 0.5   0.45  0.48  0.48  0.47  0.435]
    """
    from numpy import isnan, logical_not, minimum, nansum

    # These packages are needed only to recognise their own containers, so a
    # missing one must not prevent plain ndarray input from working.
    try:
        import dask.array as da
    except ImportError:
        da = None
    try:
        import xarray as xr
    except ImportError:
        xr = None
    try:
        from pandas import DataFrame
    except ImportError:
        DataFrame = None

    # In every branch: s0 is the allele count per variant and denom is twice
    # the number of non-missing samples.
    if da is not None and isinstance(X, da.Array):
        s0 = da.nansum(X, axis=0).compute()
        denom = 2 * (X.shape[0] - da.isnan(X).sum(axis=0)).compute()
    elif DataFrame is not None and isinstance(X, DataFrame):
        s0 = X.sum(axis=0, skipna=True)
        denom = 2 * logical_not(X.isna()).sum(axis=0)
    elif xr is not None and isinstance(X, xr.DataArray):
        if "sample" in X.dims:
            kwargs = {"dim": "sample"}
        else:
            kwargs = {"axis": 0}
        s0 = X.sum(skipna=True, **kwargs)
        denom = 2 * logical_not(isnan(X)).sum(**kwargs)
    else:
        s0 = nansum(X, axis=0)
        denom = 2 * logical_not(isnan(X)).sum(axis=0)

    s0 = s0 / denom
    s1 = 1 - s0
    # The minor allele is whichever of the two has the lower frequency.
    maf = minimum(s0, s1)

    if hasattr(maf, "name"):
        maf.name = "maf"

    return maf
def estimate_betas(y, W, E, G, maf=None, E1=None, E2=None, hK=None):
    """
    Effect sizes estimator.

    Estimates cell-level genetic effects due to GxC
    as well as persistent genetic effects across all cells.

    Parameters
    ----------
    y : array
        Phenotype
    W : array
        Fixed effect covariates
    E : array
        Cellular contexts
    G : array
        Genotypes (expanded)
    maf : array
        Minor allele frequencies (MAFs) for the SNPs in G; computed from
        ``G`` with ``compute_maf`` when omitted
    hK : array
        decomposition of kinship matrix (expanded)
    E1 : array
        Cellular contexts (C component); defaults to ``E``
    E2 : array
        Cellular contexts (K*C component); defaults to ``E``

    Returns
    -------
    betas
        The value returned by ``CellRegMap.predict_interaction``: estimated
        effect sizes, both persistent and due to GxC.
    """
    context_c = E if E1 is None else E1
    context_kc = E if E2 is None else E2

    # The K*C background is only built when a kinship decomposition is given.
    if hK is None:
        Ls = None
    else:
        Ls = get_L_values(hK, context_kc)

    crm = CellRegMap(y=y, E=E, W=W, E1=context_c, Ls=Ls)
    frequencies = compute_maf(G) if maf is None else maf
    return crm.predict_interaction(G, frequencies)