Inspiration
- Our previous work demonstrates that finding causal differences at the pseudo-bulk level is easier than dealing with noisy cell-level data: Counterfactual inference for single-cell gene expression analysis
- Interestingly, others have argued that identifying batch effects should be treated as a causal effect estimation problem: Batch Effects are Causal Effects: Applications in Human Connectomics
A generative scheme for a single-cell count matrix with multiplicative batch effects
Suppose we have single-cell expression data collected in multiple batches. One of the primary goals is to identify cell types (clusters/factors) and cell-type-specific gene expression patterns. However, distinguishing batch-specific from cell-type-specific genes with a factorization method alone is challenging, and the two are often not identifiable from the data. For each gene $g$ and cell $j$, the gene expression $Y_{gj}$ is sampled from a Poisson distribution with the rate parameter:
$$\lambda_{gj} = \lambda_{gj}^{\textsf{unbiased}} \times \prod_{k} \delta_{gk}^{X_{kj}},$$
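where $\delta_{gk}$ is the multiplicative effect of batch $k$ on gene $g$. More formally, let $X_{kj}$ be a binary batch membership matrix with $X_{kj}=1$ if and only if cell $j$ belongs to batch $k$. We then assume that the batch effects act additively on the log-transformed expression rates:
$$\ln \mathbb{E}\left[Y_{gj}\right] = \ln \lambda_{gj} = \ln \left( \sum_{t} \beta_{gt} \theta_{jt} \right) + \sum_{k} X_{kj} \ln\delta_{gk},$$
where $\beta_{gt}$ and $\theta_{jt}$ are non-negative gene and cell factor loadings, so that $\lambda_{gj}^{\textsf{unbiased}} = \sum_{t} \beta_{gt} \theta_{jt}$.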
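set.seed(1331)
m <- 500 # genes
n <- 1000 # cells
nb <- 2 # batches
## 1. batch membership: X is a cell x batch indicator matrix
X <- matrix(0, n, nb)
batch <- sample(nb, n, replace = TRUE)
for(b in 1:nb){
X[batch == b, b] <- 1
}
## 2. batch effects: ln(delta) for each gene x cell, standardized within each cell
W.true <- matrix(rnorm(m*nb), m, nb)
ln.delta <- apply(W.true %*% t(X), 2, scale)
## 3. true effects: unbiased rates from K non-negative factors
K <- 5
.beta <- matrix(rgamma(m * K, 1), m, K) # gene x factor loadings
.theta <- matrix(rgamma(n * K, 1), n, K) # cell x factor loadings
lambda.true <- .beta %*% t(.theta) # unbiased rates (gene x cell)
lambda <- lambda.true * exp(ln.delta) # rates with multiplicative batch effects
## 4. Poisson counts (gene x cell), drawn in one vectorized call
yy <- matrix(rpois(length(lambda), lambda), nrow(lambda), ncol(lambda))
## order cells by their dominant factor
oo <- order(apply(t(.theta), 2, which.max))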
If we can accurately estimate the true batch effect matrix $\delta_{gk}$, adjusting for the differences between batches is straightforward. How can we identify the true batch effect $\delta_{gk}$ for all genes $g$ specifically expressed in batch $k$? Suppose we match cells $i$ and $j$ sampled from batches $a$ and $b$, respectively. We expect the batch-specific difference ($\delta_{ga} \neq \delta_{gb}$) to persist, or even become more pronounced, while the difference originating from cell types vanishes. This problem is equivalent to estimating the potential outcome of gene expression in each batch $k$, $\mathbb{E}\left[Y_{gj}^{(k)}\right]$.
A causal inference approach to identify batch effects
To dissect batch-specific effects in a causal inference (potential outcome) framework, we assume that every batch is represented across the range of the confounding covariates $Q$:
- Overlap: $0 < p(X_{kj}=1|Q) < 1$ for all $k$.
Moreover, we assume that these covariates are sufficient to make the potential (imputed) gene expression conditionally independent of the batch assignment mechanism:
- Strong ignorability: $(Y(k), Y(k')) \perp\perp X | Q$ for all $k,k'$ pairs.
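Suppose the covariates are discrete, such as the pseudo-bulk strata used in the toy example below. Then the propensity $p(X_{kj}=1|Q)$ can be estimated by the batch fraction within each stratum, and overlap fails wherever that fraction equals 0 or 1. Below is a minimal sketch with made-up labels. `overlap_table` is a hypothetical helper, and stratification is only one of several ways to estimate propensities.

```r
## Stratified propensity estimate (a sketch): within each stratum of Q,
## p(X_k = 1 | Q) is estimated by the batch fraction n_sk / n_s.
overlap_table <- function(stratum, batch) {
    tab <- table(stratum, batch)   # strata x batches counts n_sk
    tab / rowSums(tab)             # p_sk = n_sk / n_s (row-wise division)
}

stratum.demo <- c(1, 1, 1, 2, 2, 2, 3, 3)   # made-up strata
batch.demo   <- c(1, 2, 2, 1, 1, 2, 1, 1)   # made-up batch labels
p.demo <- overlap_table(stratum.demo, batch.demo)
## strata that violate overlap: some batch fraction is exactly 0 or 1
violations <- rownames(p.demo)[apply(p.demo, 1, function(p) any(p == 0 | p == 1))]
violations
```

Here stratum 3 contains only cells from batch 1, so within that stratum no cell can be matched to a counterfactual from batch 2.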
Estimation of the batch effects by matching
Suppose we can counterfactually estimate the gene expression of a cell $j$ as if it had been measured in a batch other than its observed batch $k$:
$$Z_{gj} = \frac{ \sum_{i} (1 - X_{ki}) w_{ji} Y_{gi} }{ \sum_{i} (1 - X_{ki}) w_{ji} },$$
where $w_{ji} \ge 0$ is a matching weight between cell $j$ and cell $i$, and the factor $(1 - X_{ki})$ restricts the average to cells outside batch $k$. Like many other batch correction methods developed for single-cell RNA-seq analysis, we assume that $Z_{gj}$ reliably captures biologically relevant cell-state information while excluding the effect of the batch to which cell $j$ belongs.
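The formula for $Z_{gj}$ is a masked weighted average, and a direct translation makes the role of the mask $(1 - X_{ki})$ explicit. This is a sketch in which the weight matrix `w` is simply given, and `impute_counterfactual` is a hypothetical name. The uniform weights in the demo are for illustration only.

```r
## Counterfactual imputation (a sketch). Y: gene x cell counts;
## w: cell x cell matrix of non-negative matching weights, row j holding w_{ji};
## batch: batch label of each cell. Returns Z (gene x cell).
impute_counterfactual <- function(Y, w, batch) {
    Z <- matrix(0, nrow(Y), ncol(Y))
    for (j in seq_len(ncol(Y))) {
        mask <- as.numeric(batch != batch[j])  # (1 - X_{ki}), k = batch of cell j
        wj <- w[j, ] * mask
        Z[, j] <- (Y %*% wj) / sum(wj)         # weighted average over other batches
    }
    Z
}

Y.demo <- rbind(c(1, 2, 3, 4),
                c(10, 20, 30, 40))   # 2 genes x 4 cells
batch.demo <- c(1, 1, 2, 2)
w.demo <- matrix(1, 4, 4)            # uniform weights, for illustration only
Z.demo <- impute_counterfactual(Y.demo, w.demo, batch.demo)
Z.demo
```

With uniform weights, each cell in batch 1 receives the mean of the batch-2 cells, and vice versa. The toy example below uses sparse weights instead: only the `knn` nearest cells in each other batch get a non-zero weight, proportional to the exponential of the negative squared distance.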
Observed likelihood for the cells $j \in C_{s}$ of a pseudo-bulk sample $s$: $$\prod_{j \in C_{s}} p(Y_{gj}|\mu_{gs},\delta_{gk},X_{kj}) =\prod_{j \in C_{s}} \operatorname{Poisson}\left(Y_{gj}|\mu_{gs} \sum_{k} \delta_{gk} X_{kj}\right)$$
Counterfactual likelihood: $$\prod_{j \in C_{s}} p(Z_{gj}|\mu_{gs}, \gamma_{gs}) = \prod_{j \in C_{s}} \operatorname{Poisson}(Z_{gj}|\mu_{gs} \gamma_{gs})$$
Local update: Maximize sample $s$-specific parameters
Let's update $\mu_{gs}$ for a gene $g$ in a sample $s$:
$$\mathbb{E}\left[\mu_{gs}\right] \approx \frac{ \sum_{j \in C_{s}} Y_{gj} + \sum_{j \in C_{s}} Z_{gj} }{\sum_{k} \delta_{gk} n_{sk} + n_{s} \gamma_{gs}}$$
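Here $n_{sk}$ is the number of cells from batch $k$ in sample $s$ and $n_{s} = \sum_{k} n_{sk}$. Letting $p_{sk} = n_{sk} / n_{s}$, and writing $Y_{gs}$ and $Z_{gs}$ for the averages of $Y_{gj}$ and $Z_{gj}$ over the cells $j \in C_{s}$, the update becomes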
$$\mu_{gs} \gets \frac{ Y_{gs} + Z_{gs}}{\sum_{k} \delta_{gk} p_{sk} + \gamma_{gs}}$$
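Consider the limit $\delta_{gk} \to 0$ with $p_{sk}=1$, meaning that sample $s$ consists only of cells from batch $k$. The denominator then reduces to $\gamma_{gs}$, so $\mu_{gs} \to (Y_{gs} + Z_{gs})/\gamma_{gs}$. Moreover, $Y_{gs} = Y_{gsk} \to 0$ because the observed rate $\mu_{gs}\delta_{gk}$ vanishes. Therefore, $\mu_{gs} \to Z_{gs}/\gamma_{gs}$, which equals $Z_{gs}$ at the initialization $\gamma_{gs} = 1$.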
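The local update for $\mu_{gs}$ follows from setting the derivative of the joint log-likelihood of the observed and counterfactual models to zero. Because $Z_{gj}$ is not an integer, the Poisson density for $Z_{gj}$ acts here as a quasi-likelihood. Below is a sketch of the derivation, assuming one-hot batch membership so that $\sum_{j \in C_s} X_{kj} = n_{sk}$ and $|C_s| = n_s$:

$$\ell_{gs} = \sum_{j \in C_{s}} \left[ Y_{gj} \ln\left(\mu_{gs} \sum_{k} \delta_{gk} X_{kj}\right) - \mu_{gs} \sum_{k} \delta_{gk} X_{kj} + Z_{gj} \ln\left(\mu_{gs} \gamma_{gs}\right) - \mu_{gs} \gamma_{gs} \right] + \text{const.}$$

$$\frac{\partial \ell_{gs}}{\partial \mu_{gs}} = \frac{\sum_{j \in C_{s}} Y_{gj} + \sum_{j \in C_{s}} Z_{gj}}{\mu_{gs}} - \sum_{k} \delta_{gk} n_{sk} - n_{s} \gamma_{gs} = 0,$$

Solving this for $\mu_{gs}$ gives the update for $\mathbb{E}\left[\mu_{gs}\right]$ above. The same argument for $\gamma_{gs}$ gives

$$\frac{\partial \ell_{gs}}{\partial \gamma_{gs}} = \frac{\sum_{j \in C_{s}} Z_{gj}}{\gamma_{gs}} - n_{s} \mu_{gs} = 0 \quad\Rightarrow\quad \gamma_{gs} = \frac{1}{n_{s} \mu_{gs}} \sum_{j \in C_{s}} Z_{gj}.$$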
Global update
$$\mathbb{E}\left[\delta_{gk}\right] \approx \frac{\sum_{s} \sum_{j \in C_{s}} X_{kj} Y_{gj}}{\sum_{s} \mu_{gs} \sum_{j \in C_{s}} X_{kj}}$$
$$\delta_{gk} \gets \frac{\sum_{s} Y_{gsk} n_{sk}}{\sum_{s} \mu_{gs} n_{sk}}$$
If $Y_{gsk} \to \mu_{gs}$ for all $s$, then $\delta_{gk} \to 1$; if $Y_{gsk} < \mu_{gs}$ for all $s$, then $\delta_{gk} < 1$; and if $Y_{gsk} \to 0$ for all $s$, then $\delta_{gk} \to 0$.
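The local and global updates can be written as small functions for a single gene, which makes the limiting cases above easy to check numerically. This is a sketch: the function names and the toy numbers are hypothetical.

```r
## Single-gene versions of the local and global updates (a sketch).
## With S samples and K batches: ybar, zbar, gamma, mu are length-S vectors;
## delta is a length-K vector; p, nsk, ysk are S x K matrices (p_sk, n_sk, Y_gsk).
update_mu <- function(ybar, zbar, p, delta, gamma) {
    (ybar + zbar) / (as.vector(p %*% delta) + gamma)
}
update_gamma <- function(zbar, mu) zbar / mu
update_delta <- function(ysk, nsk, mu) {
    colSums(ysk * nsk) / colSums(mu * nsk)  # sum_s Y_gsk n_sk / sum_s mu_gs n_sk
}

nsk.demo <- cbind(c(3, 1, 2), c(1, 4, 2))   # 3 samples x 2 batches
mu.demo <- c(2, 5, 1)
d.same <- update_delta(matrix(mu.demo, 3, 2), nsk.demo, mu.demo)      # Y_gsk = mu_gs
d.half <- update_delta(matrix(mu.demo / 2, 3, 2), nsk.demo, mu.demo)  # Y_gsk < mu_gs
d.zero <- update_delta(matrix(0, 3, 2), nsk.demo, mu.demo)            # Y_gsk = 0
## limit case of the local update: p_sk = 1, delta_gk = 0, Y_gs = 0
mu.limit <- update_mu(ybar = 0, zbar = 3, p = matrix(c(1, 0), 1, 2),
                      delta = c(0, 1), gamma = 1)
c(d.same, d.half, d.zero, mu.limit)
```

The three `update_delta` calls reproduce the limits $\delta_{gk} = 1$, $\delta_{gk} < 1$ and $\delta_{gk} = 0$. The last call shows $\mu_{gs} = Z_{gs}$ when $\gamma_{gs} = 1$.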
Algorithm
1. Initialize the batch effect $\delta_{gk} \gets 1$ for each gene $g$ and batch $k$.
2. Initialize $\gamma_{gs} \gets 1$ for each gene $g$ and sample $s$.
3. Static global statistic: $S_{gk} \gets 0$.
4. For each pseudo-bulk sample $s$ with cells $C_{s}$:
   - $n_{sk} \gets \sum_{j \in C_{s}} X_{kj}$, $n_{s} \gets \sum_{k} n_{sk}$, $p_{sk} \gets n_{sk}/n_{s}$
   - $Y_{gs} \gets \sum_{j \in C_{s}} Y_{gj} / n_{s}$
   - $Y_{gsk} \gets \sum_{j \in C_{s}} Y_{gj} X_{kj} / n_{sk}$ (set $Y_{gsk} \gets 0$ if $n_{sk} = 0$)
   - $Z_{gs} \gets \sum_{j \in C_{s}} Z_{gj} / n_{s}$ after matching and imputation
   - $S_{gk} \gets S_{gk} + Y_{gsk} n_{sk}$
5. Iteratively updated global statistic: $T_{gk} \gets 0$.
6. (Local step) For each pseudo-bulk sample $s$:
   - $\delta_{gs} \gets \sum_{k} \delta_{gk} p_{sk}$
   - $\mu_{gs} \gets (Y_{gs} + Z_{gs}) / (\gamma_{gs} + \delta_{gs})$
   - $\gamma_{gs} \gets Z_{gs} / \mu_{gs}$
   - For each $k$: $T_{gk} \gets T_{gk} + \mu_{gs} n_{sk}$
7. (Global step) For each batch $k$: $\delta_{gk} \gets S_{gk} / T_{gk}$.
8. Repeat the previous three steps (5-7) until convergence.
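Vectorized over genes and samples, the iterative part of the algorithm (steps 5-7) reduces to a few matrix operations on the precomputed pseudo-bulk statistics. It mirrors the loop used in the toy example below. This is a sketch only: the function name `fit_batch_effects`, the stopping rule `tol`, and the small constant `eps` that guards against division by zero are assumptions.

```r
## Alternating local/global updates on pseudo-bulk statistics (a sketch).
## ybar, zbar: gene x sample (Y_gs, Z_gs); prob, size: batch x sample (p_sk, n_sk);
## S: gene x batch static statistic S_gk.
fit_batch_effects <- function(ybar, zbar, prob, size, S,
                              max.iter = 100, tol = 1e-6, eps = 1e-8) {
    delta <- matrix(1, nrow(ybar), nrow(prob))   # delta_gk <- 1
    gamma <- matrix(1, nrow(ybar), ncol(ybar))   # gamma_gs <- 1
    for (iter in seq_len(max.iter)) {
        ## local step: mu and gamma for all genes and samples at once
        mu <- (ybar + zbar) / (delta %*% prob + gamma + eps)
        gamma <- zbar / (mu + eps)
        ## T_gk = sum_s mu_gs n_sk (recomputed from scratch each iteration)
        T.gk <- mu %*% t(size)
        ## global step: delta_gk = S_gk / T_gk
        delta.new <- S / (T.gk + eps)
        converged <- max(abs(delta.new - delta)) < tol
        delta <- delta.new
        if (converged) break
    }
    list(delta = delta, mu = mu, gamma = gamma, iter = iter)
}

## toy statistics: 2 genes, 3 pseudo-bulk samples, 2 batches
size.demo <- rbind(c(2, 1, 0),
                   c(1, 2, 3))                             # n_sk (batch x sample)
prob.demo <- sweep(size.demo, 2, colSums(size.demo), "/")  # p_sk
ybar.demo <- matrix(c(1, 2, 3, 4, 5, 6), 2, 3)             # Y_gs
zbar.demo <- matrix(c(2, 2, 3, 3, 4, 4), 2, 3)             # Z_gs
S.demo <- matrix(c(9, 12, 18, 24), 2, 2)  # S_gk; rows sum to sum_s n_s Y_gs
fit <- fit_batch_effects(ybar.demo, zbar.demo, prob.demo, size.demo, S.demo)
fit$delta
```

On return, each block of parameters satisfies its own update equation given the others. For example, `fit$gamma * fit$mu` reproduces `zbar.demo` up to `eps`. The overall scale shared by $\mu_{gs}$, $\delta_{gk}$ and $\gamma_{gs}$ is fixed only through the initialization.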
A toy example
## 1. project
K <- 5
R <- matrix(rnorm(m * K), K, m)
Q.raw <- R %*% yy # K x n
Before regressing out batch membership, every dimension of the random projection is strongly correlated with the batch indicators:
cor(t(Q.raw), X)
## [,1] [,2]
## [1,] 0.7617260 -0.7617260
## [2,] 0.8283630 -0.8283630
## [3,] 0.8099248 -0.8099248
## [4,] -0.7250199 0.7250199
## [5,] 0.6651915 -0.6651915
## 2. regress out
##
## X theta = X inv(X'X) X' Y
## = U D V' V inv(D^2) V' (U D V')' Y
## = U inv(D) V' V D U' Y
## = U U' Y
x.svd <- svd(X)
U <- x.svd$u
U.t <- t(x.svd$u)
Q.t <- t(Q.raw)
Q.t <- Q.t - U %*% (U.t %*% Q.t) # avoids forming the n x n matrix U U'
Q <- t(Q.t)
After regressing out batch membership, the correlations vanish up to numerical precision:
cor(Q.t, X)
## [,1] [,2]
## [1,] -2.744966e-16 2.744966e-16
## [2,] -6.208765e-16 6.208765e-16
## [3,] 3.233061e-16 -3.233061e-16
## [4,] 1.992261e-16 -1.992261e-16
## [5,] -2.716381e-16 2.716381e-16
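The comment derivation in the regression step states that the hat matrix $X(X^\top X)^{-1}X^\top$ equals $UU^\top$ from the thin SVD of $X$. Because the batch design is one-hot, projecting it out amounts to subtracting each batch's mean. A small numerical check with a made-up design of 5 cells and 2 batches:

```r
## Hat matrix of a one-hot batch design: OLS form vs. SVD form
X.demo <- cbind(c(1, 1, 0, 0, 0),
                c(0, 0, 1, 1, 1))    # 5 cells x 2 batches
H.ols <- X.demo %*% solve(crossprod(X.demo)) %*% t(X.demo)
U.demo <- svd(X.demo)$u
H.svd <- U.demo %*% t(U.demo)
## regressing out batch membership subtracts each batch's mean
y.demo <- c(1, 3, 2, 4, 9)
resid.demo <- as.vector(y.demo - U.demo %*% (t(U.demo) %*% y.demo))
resid.demo   # batch means are 2 and 5
```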
q.svd <- svd(Q)
## 3. sorting
B <- (sign(q.svd$v) + 1)/2
ss <- apply(sweep(B, 2, 2^(seq(0,K-1)), `*`), 1, sum) + 1
feat.dn <- apply(Q, 2, function(x) x / sqrt(sum(x^2)))
knn <- 3
d <- nrow(feat.dn)
library(RcppAnnoy)
## a. construct dictionary for each batch
dict.list <- lapply(sort(unique(batch)),
function(b) { new(AnnoyAngular, d) })
for(j in 1:length(batch)){
b <- batch[j]
dict.list[[b]]$addItem(j, feat.dn[,j])
}
for(dd in dict.list){
dd$build(50)
}
## b. a simplified routine to retrieve and estimate counterfactual y
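.counterfactual <- function(j){
v <- feat.dn[,j]
nn <- c() # indices of matched cells from other batches
dd <- c() # squared Euclidean distances to the matched cells
for(k in 1:nb){
if(k == batch[j]) next # skip the observed batch of cell j
.nn <- dict.list[[k]]$getNNsByVector(v, knn)
.dd <- apply(feat.dn[, .nn, drop = FALSE], 2, function(u) sum((u - v)^2))
nn <- c(nn, .nn)
dd <- c(dd, .dd)
}
## softmax weights proportional to exp(-dd), shifted by min(dd) for numerical stability
w <- exp(-(dd - min(dd)))
w <- w/sum(w)
yy[, nn, drop = FALSE] %*% matrix(w, ncol=1)
}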
ngene <- nrow(yy)
nbatch <- ncol(X)
nsample <- max(ss)
.delta.db <- matrix(1, ngene, nbatch) # gene x batch effects
.delta.num.db <- matrix(0, ngene, nbatch) # gene x batch numerators
.delta.denom.db <- matrix(0, ngene, nbatch) # gene x batch denominators
.prob.bs <- matrix(0, nbatch, nsample) # batch x sample probabilities
.size.bs <- matrix(0, nbatch, nsample) # batch x sample freq
.ybar.ds <- matrix(0, ngene, nsample) # gene x sample observed average
.zbar.ds <- matrix(0, ngene, nsample) # gene x sample imputed average
.mu.ds <- matrix(1, ngene, nsample) # gene x sample adjusted average
## Precalculate some statistics
for(s in 1:nsample){
if(sum(ss == s) < 1) next
.yy <- yy[, ss == s, drop = FALSE]
.zz <- do.call(cbind, lapply(which(ss == s), .counterfactual))
.ybar.ds[,s] <- apply(.yy, 1, mean)
.zbar.ds[,s] <- apply(.zz, 1, mean)
.prob.bs[,s] <- colMeans(X[ss == s, , drop = FALSE]) # drop = FALSE keeps single-cell samples as matrices
.size.bs[,s] <- colSums(X[ss == s, , drop = FALSE])
.y.dsb <- yy[, ss == s, drop = FALSE] %*% X[ss == s, , drop = FALSE]
.delta.num.db <- .delta.num.db + .y.dsb
}
.gamma.ds <- matrix(1, ngene, nsample)
for(iter in 1:100){
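## local step: mu and gamma for every gene x sample
.mu.ds <- (.ybar.ds + .zbar.ds) / (.delta.db %*% .prob.bs + .gamma.ds + 1e-8)
.gamma.ds <- .zbar.ds / (.mu.ds + 1e-8)
## global step: delta_gk = S_gk / T_gk with T_gk = sum_s mu_gs n_sk
.delta.db <- .delta.num.db / (.mu.ds %*% t(.size.bs) + 1e-8)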
}
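The pseudo-bulk assignment `ss` in the toy example reads the signs of each cell's coordinates on the right singular vectors of `Q` as a $K$-bit binary number. There are therefore at most $2^K$ pseudo-bulk samples (32 for $K = 5$). The illustration below uses made-up coordinates with $K = 3$ and a matrix product in place of the `sweep`/`apply` expression; this is an equivalent rewrite, not the original code. An exact zero coordinate would give a fractional code, but for continuous data this has probability zero.

```r
## Sign pattern -> pseudo-bulk sample index (illustrative values, K = 3)
K.demo <- 3
v.demo <- rbind(c( 0.2, -0.1,  0.4),   # cell 1: signs + - +
                c(-0.3, -0.2, -0.1),   # cell 2: signs - - -
                c( 0.5,  0.1, -0.2))   # cell 3: signs + + -
B.demo <- (sign(v.demo) + 1) / 2       # 1 if positive, 0 if negative
ss.demo <- as.vector(B.demo %*% 2^(seq(0, K.demo - 1))) + 1
ss.demo
```

Cells with the same sign pattern fall into the same pseudo-bulk sample, so nearby cells in the projected space tend to be pooled together.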
Can we recover the original batch effects?
Figure: true delta effect (y-axis) versus estimated delta (x-axis), one panel per batch (batch #1, batch #2).
Are they independent of the cell type effects?
Figure: true y mean (y-axis) versus estimated delta (x-axis), one panel per batch (batch #1, batch #2).
After adjusting for the estimated batch effects, can we recover the unbiased cell-type effects? Before adjustment:
Figure: true y mean (y-axis) versus log1p sample mean (x-axis), one panel per batch (batch #1, batch #2).
After adjusting for the estimated batch effects:
Figure: true y mean (y-axis) versus adjusted sample mean (x-axis), one panel per batch (batch #1, batch #2).
Figure: t-SNE embedding of the cells before batch adjustment (tsne1 vs. tsne2) and after batch adjustment (tsne1 adj vs. tsne2 adj).