Statistical Operators
DiffBio provides differentiable statistical operators for probabilistic modeling, including Hidden Markov Models, generalized linear models, and EM algorithms.
Overview
The statistical operators make the following models trainable end to end:
- DifferentiableHMM: Forward algorithm with logsumexp stability
- DifferentiableNBGLM: Negative binomial GLM for differential expression
- DifferentiableEMQuantifier: Unrolled EM for transcript quantification
DifferentiableHMM
Hidden Markov Model with differentiable forward algorithm using logsumexp.
Quick Start
from flax import nnx
from diffbio.operators.statistical import DifferentiableHMM, HMMConfig
# Configure HMM
config = HMMConfig(
    n_states=5,
    n_observations=4,
    temperature=1.0,
)
# Create operator
rngs = nnx.Rngs(42)
hmm = DifferentiableHMM(config, rngs=rngs)
# Apply to observation sequence
data = {"observations": obs_sequence} # (seq_length, n_observations)
result, state, metadata = hmm.apply(data, {}, None)
# Get results
log_likelihood = result["log_likelihood"]
forward_probs = result["forward_probs"] # Alpha values
backward_probs = result["backward_probs"] # Beta values
posteriors = result["posteriors"] # State posteriors
viterbi_path = result["viterbi_path"] # Most likely path
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_states | int | 5 | Number of hidden states |
| n_observations | int | 4 | Observation vocabulary size |
| temperature | float | 1.0 | Soft Viterbi temperature |
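The table describes `temperature` only as a "Soft Viterbi temperature". One common way to relax the hard max in Viterbi decoding is a temperature-scaled logsumexp. The sketch below illustrates that idea under the assumption that the temperature plays this role; the helper name `soft_max` is hypothetical and is not part of DiffBio.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_max(scores, temperature):
    """Smooth upper bound on max(scores); approaches the hard max as temperature -> 0."""
    return temperature * logsumexp(scores / temperature)


scores = jnp.array([1.0, 2.0, 5.0])
hard = jnp.max(scores)          # non-smooth maximum
smooth = soft_max(scores, 1.0)  # slightly above the maximum, gradients reach every entry
sharp = soft_max(scores, 0.01)  # nearly identical to the hard maximum
```

Lower temperatures track the hard max more closely, but the gradient concentrates on the largest entry, so the choice trades fidelity against gradient flow.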
HMM Parameters
The HMM learns three parameter sets:
# Transition probabilities: P(state_t | state_{t-1})
hmm.transition_logits # (n_states, n_states)
# Emission probabilities: P(observation | state)
hmm.emission_logits # (n_states, n_observations)
# Initial state distribution: P(state_0)
hmm.initial_logits # (n_states,)
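Because the parameters are stored as logits, they have to be normalized before they can be read as probabilities. The sketch below assumes a softmax over the last axis, which is the usual convention but is not confirmed by the source; the arrays are random stand-ins for the three attributes above.

```python
import jax
import jax.numpy as jnp

n_states, n_observations = 5, 4
key_t, key_e, key_i = jax.random.split(jax.random.PRNGKey(0), 3)

# Hypothetical stand-ins for hmm.transition_logits, hmm.emission_logits, hmm.initial_logits
transition_logits = jax.random.normal(key_t, (n_states, n_states))
emission_logits = jax.random.normal(key_e, (n_states, n_observations))
initial_logits = jax.random.normal(key_i, (n_states,))

# Assumed normalization: each row becomes a probability distribution
transition_probs = jax.nn.softmax(transition_logits, axis=-1)  # rows: previous state
emission_probs = jax.nn.softmax(emission_logits, axis=-1)      # rows: hidden state
initial_probs = jax.nn.softmax(initial_logits)
```

Keeping the parameters unconstrained and normalizing on the fly means a gradient step can never produce an invalid probability table.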
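Forward Algorithm
The forward algorithm computes \(\alpha_t(i) = P(o_1, \ldots, o_t, s_t = i)\), the joint probability of the first \(t\) observations and of being in state \(i\) at time \(t\), through the recursion \(\alpha_t(j) = b_j(o_t) \sum_i \alpha_{t-1}(i) \, a_{ij}\), where \(a_{ij}\) is the transition probability from state \(i\) to state \(j\) and \(b_j(o_t)\) is the probability of emitting \(o_t\) in state \(j\).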
DiffBio uses log-space computation with logsumexp for numerical stability.
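As an illustration of the log-space forward pass, here is a minimal JAX sketch. It is not DiffBio's implementation: the function name `forward_log_likelihood` is hypothetical, and it assumes integer-coded observations rather than the `(seq_length, n_observations)` layout used in the Quick Start.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def forward_log_likelihood(log_init, log_trans, log_emit, observations):
    """Log-space forward pass.

    log_init: (n_states,), log_trans: (n_states, n_states) with rows = previous state,
    log_emit: (n_states, n_symbols), observations: (seq_length,) integer symbols.
    """
    log_alpha0 = log_init + log_emit[:, observations[0]]

    def step(log_alpha_prev, obs_t):
        # log alpha_t(j) = log b_j(o_t) + logsumexp_i(log alpha_{t-1}(i) + log a_ij)
        log_alpha = log_emit[:, obs_t] + logsumexp(
            log_alpha_prev[:, None] + log_trans, axis=0
        )
        return log_alpha, log_alpha

    log_alpha_last, rest = jax.lax.scan(step, log_alpha0, observations[1:])
    log_alphas = jnp.concatenate([log_alpha0[None, :], rest], axis=0)
    return logsumexp(log_alpha_last), log_alphas


# Toy model: 2 states, 2 symbols, every probability equal to 0.5
log_half = jnp.log(0.5)
log_init = jnp.full((2,), log_half)
log_trans = jnp.full((2, 2), log_half)
log_emit = jnp.full((2, 2), log_half)
obs = jnp.array([0, 1, 1])

log_lik, log_alphas = forward_log_likelihood(log_init, log_trans, log_emit, obs)
# For this uniform model the likelihood is 0.5 ** 3, so log_lik equals 3 * log(0.5)

# The whole pass is differentiable with respect to the log-parameters
grad_trans = jax.grad(
    lambda lt: forward_log_likelihood(log_init, lt, log_emit, obs)[0]
)(log_trans)
```

Working with sums of logs instead of products of probabilities avoids underflow on long sequences, and `jax.lax.scan` keeps the recursion compact while remaining traceable for gradients.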
DifferentiableNBGLM
Negative binomial GLM for modeling overdispersed count data, as used in differential expression analysis.
Quick Start
from flax import nnx
from diffbio.operators.statistical import DifferentiableNBGLM, NBGLMConfig
# Configure NB-GLM
config = NBGLMConfig(
    n_genes=2000,
    n_covariates=10,
    dispersion_model="gene",  # or "shared" or "learned"
)
# Create operator
rngs = nnx.Rngs(42)
nbglm = DifferentiableNBGLM(config, rngs=rngs)
# Fit model
data = {
    "counts": count_matrix,  # (n_samples, n_genes)
    "design": design_matrix,  # (n_samples, n_covariates)
    "size_factors": size_factors,  # (n_samples,)
}
result, state, metadata = nbglm.apply(data, {}, None)
# Get results
betas = result["coefficients"] # (n_genes, n_covariates)
dispersions = result["dispersions"] # (n_genes,)
log_likelihood = result["log_likelihood"]
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_genes | int | 2000 | Number of genes |
| n_covariates | int | 10 | Number of design matrix columns |
| dispersion_model | str | "gene" | Dispersion estimation approach ("gene", "shared", or "learned") |
Negative Binomial Distribution
The NB-GLM models the count \(y_{ij}\) for gene \(j\) in sample \(i\) as:
\[y_{ij} \sim \mathrm{NB}(\mu_{ij}, \phi_j)\]
Where:
- \(\mu_{ij} = s_i \cdot \exp(X_i \cdot \beta_j)\)
- \(s_i\) = size factor for sample \(i\)
- \(\phi_j\) = dispersion for gene \(j\)
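The definitions above can be turned into a log-likelihood that JAX can differentiate. The sketch below is illustrative only: the function names are hypothetical, and it assumes the mean-dispersion parameterization in which the variance is \(\mu + \phi \mu^2\), which the source does not state explicitly.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def nb_log_prob(counts, mu, phi):
    """Log-pmf of a negative binomial with mean mu and dispersion phi.

    Assumed parameterization: variance = mu + phi * mu**2.
    """
    r = 1.0 / phi  # inverse dispersion ("size")
    return (
        gammaln(counts + r) - gammaln(r) - gammaln(counts + 1.0)
        + r * jnp.log(r / (r + mu))
        + counts * jnp.log(mu / (r + mu))
    )


def nb_glm_log_likelihood(beta, phi, counts, design, size_factors):
    # mu_ij = s_i * exp(X_i . beta_j); beta has shape (n_genes, n_covariates)
    mu = size_factors[:, None] * jnp.exp(design @ beta.T)
    return jnp.sum(nb_log_prob(counts, mu, phi[None, :]))


# Toy data: 4 samples, 3 genes, intercept plus one condition column
design = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
size_factors = jnp.ones(4)
counts = jnp.array(
    [[3.0, 0.0, 10.0], [5.0, 1.0, 8.0], [2.0, 0.0, 12.0], [7.0, 2.0, 9.0]]
)
beta = jnp.zeros((3, 2))
phi = jnp.ones(3)

loglik = nb_glm_log_likelihood(beta, phi, counts, design, size_factors)
# Gradient with respect to the coefficients (first argument)
grad_beta = jax.grad(nb_glm_log_likelihood)(beta, phi, counts, design, size_factors)
```

Because the log-likelihood is an ordinary JAX function of `beta` and `phi`, the same expression can serve both as a fitting objective and as a loss term inside a larger pipeline.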
Differential Expression Testing
import jax.numpy as jnp
from jax.scipy.stats import norm
# Compute Wald statistics
wald_stats = result["coefficients"] / result["standard_errors"]
# Two-sided p-values; cdf(-|z|) avoids the precision loss of 1 - cdf(|z|) in the tails
p_values = 2 * norm.cdf(-jnp.abs(wald_stats))
# Log2 fold changes: coefficients are on the natural-log scale, so divide by ln(2)
log2fc = result["coefficients"][:, 1] / jnp.log(2)  # Assuming condition is column 1
DifferentiableEMQuantifier
Unrolled EM algorithm for transcript quantification, inspired by tools like Salmon/kallisto.
Quick Start
from flax import nnx
from diffbio.operators.statistical import DifferentiableEMQuantifier, EMQuantifierConfig
# Configure EM quantifier
config = EMQuantifierConfig(
    n_transcripts=50000,
    n_iterations=100,
    convergence_tol=1e-6,
)
# Create operator
rngs = nnx.Rngs(42)
em_quant = DifferentiableEMQuantifier(config, rngs=rngs)
# Quantify transcripts
data = {
    "equivalence_classes": eq_classes,  # Read assignments
    "counts": eq_counts,  # Counts per equivalence class
    "effective_lengths": eff_lengths,  # Transcript effective lengths
}
result, state, metadata = em_quant.apply(data, {}, None)
# Get TPM values
tpm = result["tpm"] # (n_transcripts,)
counts = result["estimated_counts"] # Estimated counts
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_transcripts | int | 50000 | Number of transcripts |
| n_iterations | int | 100 | Max EM iterations |
| convergence_tol | float | 1e-6 | Convergence tolerance |
EM Algorithm
The EM algorithm alternates:
E-step: Compute expected read assignments \(p(t \mid r) = \frac{\rho_t \cdot w_{rt}}{\sum_{t'} \rho_{t'} \cdot w_{rt'}}\)
M-step: Update transcript abundances \(\rho_t = \frac{\sum_r p(t \mid r)}{L_t}\)
DiffBio unrolls EM iterations for end-to-end gradient flow.
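To make the idea of unrolling concrete, here is a minimal JAX sketch of the two steps above run for a fixed number of iterations. It is not DiffBio's implementation: the function name `unrolled_em` is hypothetical, it uses a dense read-by-transcript weight matrix instead of equivalence classes, and it renormalizes \(\rho\) after each M-step, which is an added assumption.

```python
import jax
import jax.numpy as jnp


def unrolled_em(compat, eff_lengths, n_iterations=50):
    """Run a fixed number of EM steps.

    compat: (n_reads, n_transcripts) nonnegative weights w_rt; every read must be
    compatible with at least one transcript. eff_lengths: (n_transcripts,).
    Returns normalized abundances rho of shape (n_transcripts,).
    """
    n_transcripts = compat.shape[1]
    rho0 = jnp.full((n_transcripts,), 1.0 / n_transcripts)

    def em_step(rho, _):
        # E-step: p(t | r) is proportional to rho_t * w_rt
        weighted = rho[None, :] * compat
        resp = weighted / jnp.sum(weighted, axis=1, keepdims=True)
        # M-step: expected read counts divided by effective length, then renormalized
        rho_new = jnp.sum(resp, axis=0) / eff_lengths
        return rho_new / jnp.sum(rho_new), None

    rho, _ = jax.lax.scan(em_step, rho0, None, length=n_iterations)
    return rho


# Toy data: 5 reads, 3 transcripts, two reads are ambiguous
compat = jnp.array(
    [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
)
eff_lengths = jnp.array([1000.0, 500.0, 2000.0])

rho = unrolled_em(compat, eff_lengths, n_iterations=25)
tpm = rho * 1e6  # rho sums to 1, so scaling by 1e6 gives TPM-like values

# Gradients flow through every unrolled iteration
grad_lengths = jax.grad(lambda l: unrolled_em(compat, l, 25)[0])(eff_lengths)
```

A fixed iteration count keeps the computation graph static, which is what lets reverse-mode differentiation pass through the whole loop; a data-dependent convergence check would need extra care to stay differentiable.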
Training Statistical Models
HMM Training with Forward Loss
from flax import nnx
from diffbio.losses.statistical_losses import HMMLikelihoodLoss

hmm_loss = HMMLikelihoodLoss()

def train_hmm(hmm, observations):
    # Forward pass, then a loss on the sequence log-likelihood
    data = {"observations": observations}
    result, _, _ = hmm.apply(data, {}, None)
    return hmm_loss(result["log_likelihood"])

# Gradients of the loss with respect to the HMM parameters
grads = nnx.grad(train_hmm)(hmm, train_observations)
NB-GLM Training
from diffbio.losses.statistical_losses import NegativeBinomialLoss

nb_loss = NegativeBinomialLoss()

def train_nbglm(nbglm, counts, design, size_factors):
    data = {
        "counts": counts,
        "design": design,
        "size_factors": size_factors,
    }
    result, _, _ = nbglm.apply(data, {}, None)
    return nb_loss(
        counts=counts,
        predicted_mean=result["predicted_means"],
        dispersion=result["dispersions"],
    )
Use Cases
| Application | Operator | Description |
|---|---|---|
| Sequence segmentation | DifferentiableHMM | Find sequence regions |
| Gene expression | DifferentiableNBGLM | Differential expression |
| Chromatin states | DifferentiableHMM | Annotate chromatin |
| Transcript quantification | DifferentiableEMQuantifier | RNA-seq quantification |
Next Steps
- See Epigenomics Operators for the chromatin state HMM
- Explore Differential Expression Pipeline
- Check Statistical Losses for training objectives