## Read a (possibly gzip-compressed) text table through a shell pipeline.
##
## The file is streamed with `gzip -cd` when its name ends in "gz" and with
## `cat` otherwise; every run of two spaces is then rewritten to " NA " by
## sed (presumably to make empty space-separated fields explicit) before
## fread() parses the stream.
##
## x   : path to a single file
## ... : forwarded unchanged to fread()
##
## Returns whatever fread() returns for the filtered stream.
.fread <- function(x, ...) {
    stopifnot(is.character(x), length(x) == 1L, !is.na(x))

    ## `x` is a scalar, so a plain if/else replaces the vectorised ifelse()
    .reader <- if (endsWith(x, "gz")) "gzip -cd" else "cat"

    ## Quote the path so that spaces or shell metacharacters in a file name
    ## cannot break (or inject into) the command line.  "~" is expanded
    ## first because the shell does not expand it inside quotes.
    .path <- shQuote(path.expand(x))

    fread(cmd = paste0(.reader, " ", .path, " | sed 's/  / NA /g'"), ...)
}

## ggplot() wrapper with a compact house theme.
##
## ... : forwarded unchanged to ggplot()
##
## Returns ggplot(...) with theme_bw() plus the overrides below (blank
## backgrounds, small fonts, flat legend keys, gray axis lines).
.ggplot <- function(...) {

    ## shared "draw nothing" element for all background slots
    .none <- element_blank()

    ## NOTE(review): ggplot2 >= 3.4.0 deprecates `size` for line elements
    ## in favour of `linewidth`; kept as-is to preserve behaviour.
    .compact <- ggplot2::theme(
        plot.background = .none,
        plot.margin = unit(c(0,.5,0,.5), 'lines'),
        strip.background = .none,
        strip.text = element_text(size=6),
        legend.background = .none,
        legend.text = element_text(size = 8),
        legend.title = element_text(size = 8),
        axis.title = element_text(size = 8),
        legend.key.width = unit(1, 'lines'),
        legend.key.height = unit(.2, 'lines'),
        legend.key.size = unit(1, 'lines'),
        axis.line = element_line(color = 'gray20', size = .5),
        axis.text = element_text(size = 6)
    )

    ggplot(...) + theme_bw() + .compact
}

Previous results

## Read every evaluation table found under one simulation directory.
##
## hdr : path of the simulation directory.  Files are collected from
##       "<hdr>/result/" (names matching the regex "eval") and from
##       "<hdr>/summary/" (names matching the regex "eval.gz").
##
## Returns the row-bound tables with p1, pa and p0 converted to numeric.
## Stops with an error when no evaluation file is found.
.read.v3 <- function(hdr) {

    .files.1 <- list.files(hdr %&% "/result/",
                           pattern = "eval", full.names=TRUE)

    .files.2 <- list.files(hdr %&% "/summary/",
                           pattern = "eval.gz", full.names=TRUE)

    .files <- c(.files.1, .files.2)

    ## Without this guard rbind() over an empty list yields NULL and the
    ## failure surfaces later as an obscure method-dispatch error.
    if (length(.files) < 1L) {
        stop("no evaluation files found under: ", hdr, call. = FALSE)
    }

    ## columns: method p1 p0 pa rseed power10 auprc auroc r r0 s0
    ## p1, p0 and pa are read as text and prefixed with "0." below
    ## (presumably they are stored as the digits after the decimal point).
    .ct <- c("character", "character","character","character",
             "integer",
             "double","double","double","double","double","double")

    .dat <- lapply(.files, .fread, header=TRUE, fill=TRUE,
                   colClasses = .ct)

    ## The pipeline is the last expression, so the table is returned
    ## visibly (the original ended on an assignment, i.e. invisibly).
    .dat %>%
        do.call(what=rbind) %>%
        mutate(p1 = as.numeric("0." %&% p1),
               pa = as.numeric("0." %&% pa),
               p0 = as.numeric("0." %&% p0))
}

## Read one simulation directory and tag every row with the settings
## encoded in its name.
##
## x : directory name.  After removing "sim_v3_" and every letter N, M
##     and S, the remainder is split on "_" or "B" and converted to
##     integers; the first four become the constant columns N, M, S, B.
##
## Returns a data.table: the output of .read.v3(x) plus those columns.
.read.named.v3 <- function(x) {

    .stem <- str_remove(x, "sim_v3_")
    .digits <- str_remove_all(.stem, "[NMS]")
    .fields <- as.integer(unlist(str_split(.digits, "[_B]")))
    names(.fields) <- c("N","M","S","B")

    .tab <- as.data.table(.read.v3(x))

    ## add the four setting columns by reference, in the order N, M, S, B
    .tab[, `:=`(N = .fields[1],
                M = .fields[2],
                S = .fields[3],
                B = .fields[4])]

    return(.tab)
}

## Load the combined simulation results from the cache file when it
## exists; otherwise read every entry of the working directory whose name
## matches "sim_v3", row-bind the tables and write the cache.
.file <- ".simulation.v3.rdata"
if (file.exists(.file)) {
    load(.file)
} else {
    result.dt <- do.call(what = rbind,
                         lapply(list.files(".", "sim_v3"), .read.named.v3))

    save(list="result.dt", file=.file)
}

## restrict the results to N = 20 or N = 40
result.dt <- result.dt[N %in% c(20, 40)]

## per-method display settings, one row per method:
## legend label, line type, point shape and colour
col.lty.df <- tribble(
    ~method,         ~label,            ~lty, ~pch, ~col,
    "cocoa",         "CoCoA",           1,    19,   "magenta",
    "cocoa.pc.safe", "CoCoA + safe PC", 2,    21,   "magenta",
    "cocoa.pc.full", "CoCoA + 10 PC",   3,    32,   "magenta",
    "tot",           "Total",           1,    17,   "#009900",
    "tot.pc.safe",   "Tot + safe PC",   2,    24,   "#009900",
    "tot.pc.full",   "Tot + 10 PC",     3,    32,   "#009900",
    "avg",           "Average",         1,    4,    "#00FF00",
    "mu",            "Bayesian",        1,    5,    "gray40",
    "cf",            "confounder",      1,    3,    "#000099"
)

## Plot y against p1 with error bars, one curve per method.
##
## .sum : data.table with columns method, p1, p0, pa, y and y.se
##
## Methods are ranked by their largest y (descending).  That rank fixes
## the factor level order and therefore the positional matching of the
## colours, line types and point shapes taken from col.lty.df.  Methods
## without a row in col.lty.df are removed from the styling table, so
## their rows end up with an NA method level.
##
## Returns a ggplot object.
plot.v3 <- function(.sum) {

    ## rank methods by their best value of y
    .best <- .sum[, .(y.max = max(y)), by = .(method)]
    .ranked <- unlist(.best[order(.best$y.max, decreasing=TRUE), .(method)])

    ## styling rows in rank order
    .style <- na.omit(left_join(tibble(method = .ranked),
                                col.lty.df, by = "method"))

    ## relabel methods with their display names, keeping the rank order
    .labelled <- mutate(.sum,
                        method = factor(method, .style$method, .style$label))

    ## NOTE(review): .p0 and .pa are computed but not used by any layer or
    ## facet below -- confirm whether a facet was intended.
    .df <- mutate(as_tibble(.labelled),
                  .p0 = "expr. V[Y|X]=" %&% p0)
    .df <- mutate(.df, .pa = "label V[W|X]=" %&% pa)

    .aes <- aes(x = p1, y = y,
                ymin = y - y.se,
                ymax = y + y.se,
                colour = method,
                lty = method)

    ## layers: error bars first, then lines, then points on top
    .bars <- geom_linerange(.aes, lty=1, size=.2)
    .dots <- geom_point(aes(shape=method), stroke=.3, size=1.5, fill="white")

    .ggplot(.df, .aes) +
        .bars +
        geom_line() +
        .dots +
        scale_shape_manual(values = .style$pch) +
        scale_colour_manual(values = .style$col) +
        scale_linetype_manual(values = .style$lty)
}

1. Our counterfactual inference algorithm effectively adjusts for confounding factors

## Summarise column `r` per (p1, p0, pa, method) for the chosen methods.
##
## .dt     : data.table with columns method, p1, p0, pa and r
## .method : methods to keep
##
## Returns a data.table with y = mean(r) and y.se = 2 * sd(r).
## NOTE(review): y.se is twice the standard deviation and is not divided
## by sqrt(.N) as in .take.roc -- confirm this is intended.
.take.cor <- function(.dt, .method = c("cocoa","tot","cf")) {

    .kept <- .dt[method %in% .method]

    .stat <- .kept[, .(y = mean(r), y.se = 2 * sd(r)),
                   by = .(p1, p0, pa, method)]

    as.data.table(.stat)
}

## Summarise column `r0` per (p1, p0, pa, method) for the chosen methods.
##
## .dt     : data.table with columns method, p1, p0, pa and r0
## .method : methods to keep
##
## Returns a data.table with y = mean(r0) and y.se = 2 * sd(r0).
## NOTE(review): y.se is twice the standard deviation and is not divided
## by sqrt(.N) as in .take.roc -- confirm this is intended.
.take.cor.0 <- function(.dt, .method = c("cocoa","tot","cf")) {

    .kept <- .dt[method %in% .method]

    .stat <- .kept[, .(y = mean(r0), y.se = 2 * sd(r0)),
                   by = .(p1, p0, pa, method)]

    as.data.table(.stat)
}

## Summarise column `s0` per (p1, p0, pa, method) for the chosen methods.
##
## .dt     : data.table with columns method, p1, p0, pa and s0
## .method : methods to keep
##
## Returns a data.table with y = mean(s0) and y.se = 2 * sd(s0).
## NOTE(review): y.se is twice the standard deviation and is not divided
## by sqrt(.N) as in .take.roc -- confirm this is intended.
.take.cor.null <- function(.dt, .method = c("cocoa","tot","cf")) {

    .kept <- .dt[method %in% .method]

    .stat <- .kept[, .(y = mean(s0), y.se = 2 * sd(s0)),
                   by = .(p1, p0, pa, method)]

    as.data.table(.stat)
}
## Figure Q1: one three-panel figure per value of S (2 and 5).
## Per the plot title below, S is the number of confounders and
## B = 5 - S the number of non-confounding covariates.
for(ss in c(2, 5)) {

    ## Rows for this setting with N = 40, M = 50 and p1 <= 0.3.
    ## Inside the data.table filter `p1` is the column; the plot object
    ## named p1 assigned below does not interfere with it.
    .dt <- result.dt[S==ss & B == (5- ss) & N == 40 & M == 50 & p1 <= .3]

    ## panel 1: summary of column r (see .take.cor)
    p1 <- 
        plot.v3(.take.cor(.dt)) +
        xlab("V[Y|W]") +
        ylab("correlation with the unconfounded\n(causal genes)") +
        ggtitle("#confounder: " %&% ss %&%
                ", #non-confounding covariate: " %&% (5 - ss))

    ## panel 2: summary of column r0 (see .take.cor.0)
    p2 <- 
        plot.v3(.take.cor.0(.dt)) +
        xlab("V[Y|W]") +
        ylab("correlation with the unconfounded\n(non-causal genes)")

    ## panel 3: summary of column s0 (see .take.cor.null)
    p3 <- 
        plot.v3(.take.cor.null(.dt)) +
        xlab("V[Y|W]") +
        ylab("correlation with the confounders\n(non-causal genes)")

    ## combine the three panels (presumably patchwork's `/`, which stacks
    ## plots vertically -- TODO confirm the package providing it)
    plt <- p1/p2/p3
    print(plt)

    ## fig.dir and .gg.save are defined outside this part of the file
    .file <- fig.dir %&% "/Fig_Q1_S" %&% ss %&% ".pdf"
    .gg.save(.file, plot = plt, width=4, height=6)
}

PDF

PDF

2. CoCoA would not over-correct the variance even if there were no confounding effect

## Summarise auprc and power10 per (p1, p0, pa, method) for the chosen
## methods, stacked in long format.
##
## .dt     : data.table with columns method, p1, p0, pa, auprc, power10
## .method : methods to keep
##
## Returns a data.table with y = mean, y.se = 2 * sd / sqrt(group size)
## and a `metric` column ("auprc" rows first, then "power10" rows).
.take.roc <- function(.dt, .method = c("cocoa", "tot", "avg", "mu", "cf")) {

    .kept <- .dt[method %in% .method]

    .prc <- mutate(
        .kept[, .(y = mean(auprc), y.se = 2 * sd(auprc)/sqrt(.N)),
              by = .(p1, p0, pa, method)],
        metric = "auprc")

    .pow <- mutate(
        .kept[, .(y = mean(power10), y.se = 2 * sd(power10)/sqrt(.N)),
              by = .(p1, p0, pa, method)],
        metric = "power10")

    as.data.table(rbind(.prc, .pow))
}
## Figure Q2: power and AUPRC summaries for the rows with B = 5, N = 40
## and p1 <= 0.3 (`p1` in the filter is the data.table column).
.dt <- .take.roc(result.dt[B==5 & N == 40 & p1 <= .3])

## top panel: power10 summary; bottom panel: auprc summary
p1 <- plot.v3(.dt[metric=="power10"]) + ylab("Power (FDR < 10%)")
p2 <- plot.v3(.dt[metric=="auprc"]) + ylab("AUPRC")

plt <- p1/p2
print(plt)

## fig.dir and .gg.save are defined outside this part of the file
.file <- fig.dir %&% "/Fig_Q2_no_confounding.pdf"
.gg.save(.file, plot = plt, width=4, height=4)

PDF

3. CoCoA improves statistical power

## Figure Q3: one two-panel figure per value of S (0, 2 and 5); the plot
## title below labels S as the number of confounders.
for(ss in c(0, 2, 5)){

    ## rows for this S with N = 40 and M = 50, summarised by .take.roc
    .dt <- result.dt[S==ss & N == 40 & M == 50] %>% .take.roc

    ## top panel: power10 summary
    p1 <-
        plot.v3(.dt[metric=="power10"]) +
        ylab("Power (FDR < 10%)") +
        ggtitle("#confounder: " %&% ss)

    ## bottom panel: auprc summary
    p2 <-
        plot.v3(.dt[metric=="auprc"]) +
        ylab("AUPRC")

    ## combine the two panels (presumably patchwork's `/` -- TODO confirm)
    plt <- p1/p2
    print(plt)
    ## fig.dir and .gg.save are defined outside this part of the file
    .file <- fig.dir %&% "/Fig_Q3_S" %&% ss %&% ".pdf"
    .gg.save(.file, plot = plt, width=4, height=4)
}

PDF

PDF

PDF

4. Additional principal component analysis can help, but over-correction by matrix factorization can hurt

## Figure Q4: AUPRC box plots for four methods, faceted by N (rows) and
## S (columns), using the rows with M = 50 and pa = 0.5.
.method <- c("cocoa", "tot", "cocoa.pc.safe", "tot.pc.safe")

## N becomes a factor so that the "N = 40" facet row comes before "N = 20"
.dt <- mutate(result.dt[method %in% .method & M == 50 & pa == .5],
              N = factor(N, c(40, 20), c("N = 40", "N = 20")))

.aes <- aes(x = as.factor(p1), y = auprc, fill = method)

plt <- .ggplot(.dt, .aes) +
    geom_boxplot(size = .2, outlier.size = 0, outlier.stroke = 0) +
    facet_grid(N ~ S) +
    scale_fill_brewer(palette = "Paired") +
    xlab("Variance Caused by Disease Effect") +
    ylab("AUPRC")

print(plt)

## fig.dir and .gg.save are defined outside this part of the file
.file <- fig.dir %&% "/Fig_Q4_AUPRC.pdf"
.gg.save(.file, plot = plt, width=7, height=5)

PDF

5. What if we increase the sample size?

## Figure Q5: violin plots of power10 (top) and auprc (bottom) against the
## sample size N for two methods, one facet per value of p1.
.methods <- c("cocoa", "tot")
.lab <- c("CoCoA", "Total")

## rows with S = 5, M = 50, pa = 0.5 and p0 = 0.5; methods are relabelled
## with their display names in the order given by .methods / .lab
.dt <- result.dt[S == 5 & M == 50 & pa == .5 & p0 == .5 & method %in% .methods] %>%
    mutate(method = factor(method, .methods, .lab))

## aesthetics for the power panel
.aes <- aes(x = as.factor(N),
            y = power10,
            fill = method)

## NOTE(review): .gg.plot is not defined in this part of the file and is
## distinct from the .ggplot helper defined above -- confirm it is the
## intended helper.
p1 <-
    .gg.plot(.dt) +
    ylab("Power (FDR < 10%)") + xlab("# individuals") +
    facet_wrap(~ p1, nrow = 1) +
    geom_violin(.aes, size = .2, scale="width", draw_quantiles = c(.5)) +
    scale_y_continuous(limits = c(0, 1)) +
    scale_fill_manual(values = c("#f768a1", "gray")) +
    theme(legend.position = "bottom")

## aesthetics for the AUPRC panel; .aes is reassigned, so it must stay
## after p1 has been built
.aes <- aes(x = as.factor(N),
            y = auprc,
            fill = method)

## same layers as p1, only the y aesthetic and its axis label differ
p2 <-
    .gg.plot(.dt) +
    ylab("AUPRC") + xlab("# individuals") +
    facet_wrap(~ p1, nrow = 1) +
    geom_violin(.aes, size = .2, scale="width", draw_quantiles = c(.5)) +
    scale_y_continuous(limits = c(0, 1)) +
    scale_fill_manual(values = c("#f768a1", "gray")) +
    theme(legend.position = "bottom")

## combine the two panels (presumably patchwork's `/` -- TODO confirm)
plt <- p1/p2
print(plt)
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

## fig.dir and .gg.save are defined outside this part of the file
.file <- fig.dir %&% "/Fig_Q5_sample_size.pdf"
.gg.save(.file, plot = plt, width=6, height=4)
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

PDF