Differential Expression Pipeline¤
The DifferentialExpressionPipeline provides DESeq2-style differential expression analysis with end-to-end differentiability.
Overview¤
The differential expression pipeline implements:
- Size Factor Estimation: Median-of-ratios normalization
- Dispersion Estimation: Gene-wise dispersion fitting
- NB-GLM Fitting: Negative binomial generalized linear model
- Statistical Testing: Wald test for significance
- Multiple Testing Correction: Benjamini-Hochberg FDR
Quick Start¤
from flax import nnx
from diffbio.pipelines import (
    DifferentialExpressionPipeline,
    DEPipelineConfig,
)
# Configure pipeline
config = DEPipelineConfig(
    n_genes=2000,
    n_conditions=2,
    alpha=0.05,
)
# Create pipeline
rngs = nnx.Rngs(42)
de_pipeline = DifferentialExpressionPipeline(config, rngs=rngs)
# Run differential expression analysis
data = {
    "counts": count_matrix,   # (n_samples, n_genes)
    "design": design_matrix,  # (n_samples, n_conditions)
}
result, state, metadata = de_pipeline.apply(data, {}, None)
# Get results
log2fc = result["log_fold_change"] # (n_genes,)
pvalues = result["p_values"] # (n_genes,)
significant = result["significant"] # Soft significance indicator
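The quick start assumes that `count_matrix` and `design_matrix` already exist. A minimal sketch for building toy inputs with the documented shapes is shown here; the sample count, the Poisson rate, and the intercept-plus-treatment layout of the design matrix are assumptions made for illustration, not requirements stated by the library.

```python
import jax
import jax.numpy as jnp

n_samples, n_genes = 6, 2000  # illustrative sizes (assumed)

# Toy counts: Poisson draws stand in for real RNA-seq counts.
key = jax.random.PRNGKey(0)
count_matrix = jax.random.poisson(
    key, lam=20.0, shape=(n_samples, n_genes)
).astype(jnp.float32)

# Two-column design: intercept plus a 0/1 treatment indicator (assumed layout).
treatment = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
design_matrix = jnp.stack([jnp.ones(n_samples), treatment], axis=1)
```

With real data, replace the Poisson draws with the raw (unnormalized) count matrix, samples in rows and genes in columns.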
Configuration¤
DEPipelineConfig¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_genes | int | 1000 | Number of genes |
| n_conditions | int | 2 | Number of conditions/covariates |
| alpha | float | 0.05 | Significance threshold |
| use_size_factors | bool | True | Whether to compute size factors |
Detailed Configuration¤
config = DEPipelineConfig(
    # Data dimensions
    n_genes=2000,
    n_conditions=2,
    # Testing
    alpha=0.05,
    use_size_factors=True,
)
Pipeline Stages¤
Stage 1: Size Factor Estimation¤
Median-of-ratios method (DESeq2 style):
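import jax.numpy as jnp
# Geometric mean per gene across samples; the +1 pseudocount keeps log finite for zero counts
geo_means = jnp.exp(jnp.mean(jnp.log(counts + 1), axis=0))  # (n_genes,)
# Ratio of each sample's counts to the per-gene reference
ratios = counts / geo_means  # (n_samples, n_genes)
# Size factor per sample: median ratio across genes
size_factors = jnp.median(ratios, axis=1)  # (n_samples,)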
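As a sanity check on the idea, the sketch below applies a log-space variant of median-of-ratios to a toy matrix in which the second sample is sequenced exactly twice as deeply as the first. It assumes strictly positive counts and omits the pseudocount, so it is not necessarily identical to what the pipeline computes; `median_of_ratios` is a name chosen here for illustration.

```python
import jax.numpy as jnp

def median_of_ratios(counts):
    """Size factors for a (n_samples, n_genes) matrix of strictly positive counts."""
    log_counts = jnp.log(counts)
    log_geo_means = log_counts.mean(axis=0)         # per-gene log geometric mean
    log_ratios = log_counts - log_geo_means         # log ratio of each sample to the reference
    return jnp.exp(jnp.median(log_ratios, axis=1))  # per-sample median ratio

base = jnp.array([10.0, 20.0, 40.0, 80.0])
counts = jnp.stack([base, 2.0 * base])  # second sample has double the depth
size_factors = median_of_ratios(counts)
```

The two size factors differ by a factor of two, so dividing each sample by its size factor puts both on the same scale.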
Stage 2: Dispersion Estimation¤
Gene-wise dispersion with shrinkage:
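# Initial dispersion estimate (method of moments), from Var = mean + alpha * mean^2
mean_counts = counts.mean(axis=0)
var_counts = counts.var(axis=0)
alpha_init = (var_counts - mean_counts) / (mean_counts ** 2)
# Genes with variance below the mean give negative estimates; clamp to a small positive floor
alpha_init = jnp.maximum(alpha_init, 1e-8)
# Shrinkage towards trend (shrink_dispersions stands for the trend-shrinkage step)
dispersion = shrink_dispersions(alpha_init, mean_counts)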
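The `shrink_dispersions` call is not spelled out in this section. The following is a hypothetical stand-in, not the pipeline's implementation: it pulls each gene-wise estimate toward a parametric trend of the form `a1 / mean + a0` using a fixed weight in log space. The trend coefficients, the weight, and the floor `eps` are all illustrative assumptions.

```python
import jax.numpy as jnp

def shrink_dispersions(alpha_init, mean_counts, weight=0.5, a0=0.05, a1=1.0, eps=1e-8):
    """Hypothetical stand-in: shrink gene-wise dispersions toward a parametric trend.

    The trend a1 / mean + a0 and the fixed weight are illustrative assumptions,
    not values taken from the pipeline.
    """
    trend = a1 / jnp.maximum(mean_counts, eps) + a0
    log_alpha = jnp.log(jnp.maximum(alpha_init, eps))
    # A convex combination in log space keeps the result positive and differentiable.
    return jnp.exp((1.0 - weight) * log_alpha + weight * jnp.log(trend))

alpha_init = jnp.array([0.01, 1.0])
mean_counts = jnp.array([100.0, 100.0])
shrunk = shrink_dispersions(alpha_init, mean_counts)
```

Both estimates move toward the trend value of 0.06: the low one increases and the high one decreases. An empirical Bayes scheme would instead choose the weight per gene from the data.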
Stage 3: NB-GLM Fitting¤
Fit negative binomial GLM:
\[Y_{ij} \sim \mathrm{NB}(\mu_{ij}, \alpha_j)\]
\[\log(\mu_{ij}) = \log(s_i) + X_i \cdot \beta_j\]
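Where:
- \(Y_{ij}\) = observed count for sample \(i\) and gene \(j\)
- \(\mu_{ij}\) = expected count for sample \(i\) and gene \(j\)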
- \(s_i\) = size factor for sample \(i\)
- \(X_i\) = design matrix row
- \(\beta_j\) = coefficients for gene \(j\)
- \(\alpha_j\) = dispersion for gene \(j\)
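To make the model concrete, here is a sketch of the negative binomial log-likelihood under the mean/dispersion parameterization above, with \(r_j = 1/\alpha_j\). The function name `nb_log_likelihood` and the coefficient layout `(n_conditions, n_genes)` are assumptions for this example and may not match the pipeline's internals.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def nb_log_likelihood(counts, design, beta, size_factors, dispersion):
    """Sum of NB log-probabilities with log(mu_ij) = log(s_i) + X_i . beta_j."""
    r = 1.0 / dispersion                                     # (n_genes,)
    log_mu = jnp.log(size_factors)[:, None] + design @ beta  # (n_samples, n_genes)
    log_r_plus_mu = jnp.log(r + jnp.exp(log_mu))
    ll = (
        gammaln(counts + r) - gammaln(r) - gammaln(counts + 1.0)
        + r * (jnp.log(r) - log_r_plus_mu)
        + counts * (log_mu - log_r_plus_mu)
    )
    return ll.sum()

# One gene, two samples, intercept-only design, mu = 1 and alpha = 1.
counts = jnp.array([[0.0], [1.0]])
design = jnp.ones((2, 1))
beta = jnp.zeros((1, 1))
ll = nb_log_likelihood(counts, design, beta, jnp.ones(2), jnp.array([1.0]))
```

With \(\mu = 1\) and \(\alpha = 1\) the distribution is geometric, so \(P(0) = 1/2\) and \(P(1) = 1/4\), and the total log-likelihood is \(\log(1/8)\). Because the function is written in JAX, `jax.grad` with respect to `beta` or `dispersion` gives the gradients needed to fit the GLM.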
Stage 4: Wald Test¤
Test for significant coefficients:
from jax.scipy.stats import norm
# Wald statistic
wald_stat = beta / se_beta
# Two-sided p-value; norm.cdf(-|z|) avoids the cancellation in 1 - norm.cdf(|z|) for large |z|
pvalue = 2 * norm.cdf(-jnp.abs(wald_stat))
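The snippet above takes `se_beta` as given. One standard way to obtain it for a negative binomial GLM with a log link is from the inverse Fisher information, \((X^\top W X)^{-1}\) with weights \(w_i = \mu_i / (1 + \alpha \mu_i)\). The sketch below does this for a single gene; whether the pipeline computes its standard errors this way is an assumption, and `wald_standard_errors` is a name introduced here.

```python
import jax.numpy as jnp

def wald_standard_errors(design, mu, dispersion):
    """Standard errors of one gene's coefficients: sqrt(diag((X^T W X)^-1))."""
    w = mu / (1.0 + dispersion * mu)           # NB working weights under a log link
    fisher = design.T @ (w[:, None] * design)  # (n_conditions, n_conditions)
    cov = jnp.linalg.inv(fisher)
    return jnp.sqrt(jnp.diag(cov))

# Two control and two treated samples, intercept plus treatment indicator.
design = jnp.array([[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]])
mu = jnp.full((4,), 10.0)
se = wald_standard_errors(design, mu, dispersion=0.1)
```

Here every weight equals 5, the covariance matrix is `[[0.1, -0.1], [-0.1, 0.2]]`, and the treatment coefficient has a standard error of about 0.447.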
Stage 5: Multiple Testing Correction¤
Benjamini-Hochberg FDR:
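import jax
# Sort p-values
sorted_idx = jnp.argsort(pvalues)
sorted_pvals = pvalues[sorted_idx]
# BH correction: scale the k-th smallest p-value by n / k
n = len(pvalues)
adjusted = sorted_pvals * n / (jnp.arange(n) + 1)
# Enforce monotonicity with a reverse cumulative minimum, then cap at 1
adjusted = jnp.minimum(jax.lax.cummin(adjusted, reverse=True), 1.0)
# Unsort
padj = adjusted[jnp.argsort(sorted_idx)]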
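A small worked example helps check the procedure by hand. The sketch below wraps the Benjamini-Hochberg steps in a function (the name `benjamini_hochberg` is chosen here, not taken from the library) and applies it to four p-values.

```python
import jax
import jax.numpy as jnp

def benjamini_hochberg(pvalues):
    """BH-adjusted p-values for a 1-D array, returned in the input order."""
    n = pvalues.shape[0]
    order = jnp.argsort(pvalues)
    scaled = pvalues[order] * n / jnp.arange(1, n + 1)  # p_(k) * n / k
    # Reverse cumulative minimum keeps adjusted values monotone in rank.
    monotone = jax.lax.cummin(scaled, reverse=True)
    return jnp.clip(monotone, 0.0, 1.0)[jnp.argsort(order)]

pvals = jnp.array([0.01, 0.04, 0.03, 0.20])
padj = benjamini_hochberg(pvals)
```

The sorted p-values 0.01, 0.03, 0.04, 0.20 scale to 0.04, 0.06, 0.0533, 0.20. The cumulative minimum lowers 0.06 to 0.0533, so the adjusted values in the original order are approximately 0.04, 0.0533, 0.0533, 0.20.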
Output Format¤
The pipeline returns a dictionary with:
| Key | Shape | Description |
|---|---|---|
| size_factors | (n_samples,) | Sample size factors (median-of-ratios) |
| predicted_mean | (n_samples, n_genes) | NB-GLM predicted mean expression |
| log_fold_change | (n_genes,) | Log2 fold change |
| wald_statistic | (n_genes,) | Wald test statistics |
| standard_error | (n_genes,) | Standard errors of treatment coefficient |
| p_values | (n_genes,) | Raw two-sided p-values |
| significant | (n_genes,) | Soft significance indicator |
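The table describes `significant` as a soft indicator without defining it. One plausible construction, shown purely as an assumption about how such an indicator could work, is a sigmoid relaxation of the hard test `p < alpha`. The function name and the `temperature` parameter are hypothetical.

```python
import jax
import jax.numpy as jnp

def soft_significance(p_values, alpha=0.05, temperature=0.01):
    """Hypothetical relaxation of (p < alpha): near 1 when p << alpha, near 0 when p >> alpha."""
    return jax.nn.sigmoid((alpha - p_values) / temperature)

p = jnp.array([0.001, 0.05, 0.9])
soft = soft_significance(p)
hard = soft > 0.5  # boolean mask, e.g. for indexing or plotting
```

A soft indicator keeps gradients flowing through the significance call; thresholding it, as in the last line, recovers a boolean mask when a hard decision is needed.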
Training / Fine-tuning¤
Loss Function for DE¤
import jax.numpy as jnp
from diffbio.losses.statistical_losses import NegativeBinomialLoss
nb_loss = NegativeBinomialLoss()
def de_loss(pipeline, counts, design, known_de_genes):
    """Train DE pipeline with known DE genes."""
    data = {"counts": counts, "design": design}
    result, _, _ = pipeline.apply(data, {}, None)
    # Likelihood loss (NB GLM dispersion is held inside the pipeline module)
    dispersion = jnp.exp(pipeline.nb_glm.log_dispersion[...])
    likelihood = nb_loss(
        counts,
        result["predicted_mean"],
        dispersion,
    )
    # Optional: supervised loss if DE genes are known
    # (binary_cross_entropy is a user-supplied function; it is not imported above)
    if known_de_genes is not None:
        sig_loss = binary_cross_entropy(
            result["significant"].astype(float),
            known_de_genes.astype(float),
        )
        return likelihood + 0.1 * sig_loss
    return likelihood
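The loss above calls `binary_cross_entropy`, which the snippet does not import. A minimal stand-in, assuming the soft significance values are probabilities in [0, 1], could look like the following; the clipping constant `eps` is an illustrative choice.

```python
import jax.numpy as jnp

def binary_cross_entropy(probs, targets, eps=1e-7):
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    probs = jnp.clip(probs, eps, 1.0 - eps)  # keep the logarithms finite
    return -jnp.mean(
        targets * jnp.log(probs) + (1.0 - targets) * jnp.log1p(-probs)
    )

loss = binary_cross_entropy(jnp.array([0.5, 0.5]), jnp.array([1.0, 0.0]))
```

For uninformative predictions of 0.5 the loss equals \(\log 2\), which is a convenient baseline when monitoring training.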
End-to-End Optimization¤
import jax.numpy as jnp
import optax
from flax import nnx
# Create optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(nnx.state(de_pipeline, nnx.Param))
# nnx.jit (not jax.jit) accepts the module as an argument and propagates its updated state
@nnx.jit
def train_step(pipeline, counts, design, opt_state):
    def loss_fn(pipe):
        data = {"counts": counts, "design": design}
        result, _, _ = pipe.apply(data, {}, None)
        dispersion = jnp.exp(pipe.nb_glm.log_dispersion[...])
        return nb_loss(counts, result["predicted_mean"], dispersion)
    loss, grads = nnx.value_and_grad(loss_fn)(pipeline)
    params = nnx.state(pipeline, nnx.Param)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    nnx.update(pipeline, optax.apply_updates(params, updates))
    return loss, opt_state
Visualization¤
Volcano Plot¤
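import jax.numpy as jnp
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 8))
log2fc = result["log_fold_change"]
p_values = result["p_values"]
# Clamp to the smallest positive float so log10 stays finite (1e-300 underflows to 0 in float32)
neg_log_p = -jnp.log10(jnp.maximum(p_values, jnp.finfo(p_values.dtype).tiny))
# "significant" is a soft indicator; threshold it for a boolean mask (0.5 cutoff is an assumption)
significant = result["significant"] > 0.5
# Non-significant
ax.scatter(
    log2fc[~significant],
    neg_log_p[~significant],
    c='gray', alpha=0.5, s=10
)
# Significant
ax.scatter(
    log2fc[significant],
    neg_log_p[significant],
    c='red', alpha=0.7, s=20
)
ax.axhline(-jnp.log10(0.05), ls='--', c='black')
ax.axvline(-1, ls='--', c='black')
ax.axvline(1, ls='--', c='black')
ax.set_xlabel('Log2 Fold Change')
# Raw p-values are plotted, so the axis is not labelled "adjusted"
ax.set_ylabel('-Log10 P-value')
ax.set_title('Volcano Plot')
plt.show()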
MA Plot¤
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Normalize counts by size factors for the MA plot
normalized_counts = data["counts"] / result["size_factors"][:, None]
base_mean = normalized_counts.mean(axis=0)
log2fc = result["log_fold_change"]
# Threshold the soft indicator to get a boolean mask (0.5 cutoff is an assumption)
significant = result["significant"] > 0.5
plt.figure(figsize=(10, 8))
plt.scatter(jnp.log10(base_mean + 1), log2fc, c='gray', alpha=0.5, s=10)
plt.scatter(
    jnp.log10(base_mean[significant] + 1),
    log2fc[significant],
    c='red', alpha=0.7, s=20
)
plt.axhline(0, ls='--', c='black')
plt.xlabel('Log10 Mean Expression')
plt.ylabel('Log2 Fold Change')
plt.title('MA Plot')
plt.show()
Use Cases¤
| Application | Description |
|---|---|
| RNA-seq DE | Compare gene expression between conditions |
| Single-cell DE | Marker gene discovery |
| Time series | Identify temporally regulated genes |
| Drug response | Find genes responding to treatment |
Next Steps¤
- See Statistical Operators for NB-GLM details
- Explore Statistical Losses for loss functions
- Check the Differential Expression Example for a complete walkthrough