Statistical Losses¤
DiffBio provides statistical loss functions for probabilistic models and count data.
NegativeBinomialLoss¤
Negative binomial log-likelihood loss for overdispersed count data.
Overview¤
The negative binomial distribution is commonly used in genomics for modeling RNA-seq counts, which exhibit overdispersion (variance > mean).
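To make the overdispersion concrete: under the mean/dispersion parameterization commonly used for RNA-seq (assumed here, with dispersion \(\alpha\)), the negative binomial variance is \(\mu + \alpha\mu^2\), whereas a Poisson model fixes the variance at \(\mu\). The helper `nb_variance` below is illustrative and not part of DiffBio.

```python
def nb_variance(mu: float, alpha: float) -> float:
    """Variance of a negative binomial with mean mu and dispersion alpha.

    Assumes the mean/dispersion parameterization: Var = mu + alpha * mu^2.
    """
    return mu + alpha * mu**2


mean = 10.0
poisson_variance = mean                  # Poisson: variance equals the mean
nb_var = nb_variance(mean, alpha=0.5)    # NB: extra alpha * mu^2 term
print(poisson_variance, nb_var)
```

With \(\alpha = 0\) the extra term vanishes and the variance falls back to the Poisson value.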
Usage¤
```python
from diffbio.losses.statistical_losses import NegativeBinomialLoss

# Create loss function
nb_loss = NegativeBinomialLoss()

# Compute loss (signature: counts, mu, theta)
loss = nb_loss(
    observed_counts,   # (n_samples, n_genes)
    model_means,       # (n_samples, n_genes) predicted mean (mu)
    gene_dispersions,  # (n_genes,) dispersion (theta)
)
```
Parameters¤
The loss takes no configuration parameters; all inputs are provided at call time.
Algorithm¤
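Negative binomial log-likelihood of a single observed count (the loss negates this quantity):
\[
\log P(y \mid \mu, \alpha) = \log\Gamma\!\left(y + \tfrac{1}{\alpha}\right) - \log\Gamma(y + 1) - \log\Gamma\!\left(\tfrac{1}{\alpha}\right) + \tfrac{1}{\alpha}\log\frac{1}{1 + \alpha\mu} + y\log\frac{\alpha\mu}{1 + \alpha\mu}
\]
Where: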
- \(y\) = observed count
- \(\mu\) = predicted mean
- \(\alpha\) = dispersion (overdispersion parameter)
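As a sanity check on the mean/dispersion form of the log-likelihood, the sketch below evaluates it in plain Python (standard library only) and confirms that the probabilities sum to 1 and that the distribution's mean recovers \(\mu\). It assumes the parameterization with size \(1/\alpha\) and \(\mu > 0\); `nb_log_prob` is an illustrative name, not a DiffBio function, and the mean reduction in the last step is an assumption.

```python
import math


def nb_log_prob(y: int, mu: float, alpha: float) -> float:
    """Log-probability of count y under NB(mean=mu, dispersion=alpha), mu > 0."""
    r = 1.0 / alpha  # inverse dispersion ("size")
    return (
        math.lgamma(y + r)
        - math.lgamma(y + 1)
        - math.lgamma(r)
        + r * math.log(1.0 / (1.0 + alpha * mu))
        + y * math.log(alpha * mu / (1.0 + alpha * mu))
    )


mu, alpha = 2.0, 0.5

# Probabilities over all counts should sum to 1 (truncated at a large count).
total = sum(math.exp(nb_log_prob(y, mu, alpha)) for y in range(500))

# The mean implied by the distribution should recover mu.
implied_mean = sum(y * math.exp(nb_log_prob(y, mu, alpha)) for y in range(500))

# Mean negative log-likelihood of a few observed counts (assumed reduction).
counts = [0, 1, 3, 7]
nll = -sum(nb_log_prob(y, mu, alpha) for y in counts) / len(counts)
print(total, implied_mean, nll)
```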
Training Example¤
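```python
import jax.numpy as jnp
from flax import nnx


def de_model_loss(model, counts, design):
    """Negative binomial loss for a differential expression model."""
    data = {"counts": counts, "design": design}
    result, _, _ = model.apply(data, {}, None)
    # Dispersion is stored on the log scale; exponentiate to keep it positive.
    dispersion = jnp.exp(model.nb_glm.log_dispersion[...])
    return nb_loss(
        counts,
        result["predicted_mean"],
        dispersion,
    )


# Gradients with respect to the model's parameters
grads = nnx.grad(de_model_loss)(model, counts, design)
```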
VAELoss¤
Evidence Lower Bound (ELBO) loss for variational autoencoders.
Overview¤
VAE loss combines reconstruction loss with KL divergence regularization, enabling probabilistic latent space learning.
Usage¤
```python
from diffbio.losses.statistical_losses import VAELoss

# Create loss function
vae_loss = VAELoss(
    kl_weight=1.0,
    reconstruction_type="mse",  # or "bce" for binary cross-entropy
)

# Compute loss (signature: x, x_recon, mean, logvar)
loss = vae_loss(
    original_input,       # (n_samples, n_features) x
    decoded_output,       # (n_samples, n_features) x_recon
    latent_mean,          # (n_samples, latent_dim) encoder mean
    latent_log_variance,  # (n_samples, latent_dim) encoder logvar
)
```
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| `kl_weight` | float | 1.0 | KL divergence weight (beta-VAE) |
| `reconstruction_type` | str | "mse" | Reconstruction loss type ("mse" or "bce") |
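As a rough guide to choosing `reconstruction_type`: mean squared error suits real-valued inputs, while binary cross-entropy expects targets and reconstructions in \([0, 1]\). The sketch below uses the standard textbook definitions in plain Python. The function names are hypothetical, and the mean-over-features reduction is an assumption that may differ from what `VAELoss` does internally.

```python
import math


def mse_reconstruction(x, x_recon):
    """Mean squared error between input and reconstruction."""
    return sum((a - b) ** 2 for a, b in zip(x, x_recon)) / len(x)


def bce_reconstruction(x, x_recon, eps=1e-7):
    """Binary cross-entropy; x_recon entries are probabilities in (0, 1)."""
    total = 0.0
    for a, p in zip(x, x_recon):
        p = min(max(p, eps), 1.0 - eps)  # clamp to keep log() finite
        total += -(a * math.log(p) + (1.0 - a) * math.log(1.0 - p))
    return total / len(x)


x = [0.0, 1.0, 1.0, 0.0]
x_recon = [0.1, 0.8, 0.6, 0.3]
print(mse_reconstruction(x, x_recon), bce_reconstruction(x, x_recon))
```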
Algorithm¤
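ELBO loss:
\[
L = L_{recon} + \beta \, D_{KL}
\]
Where:
- \(L_{recon}\) = reconstruction loss (MSE or BCE)
- \(D_{KL}\) = KL divergence to prior
- \(\beta\) = KL weight (set > 1 for disentanglement)
KL divergence for a Gaussian posterior against a standard normal prior, with encoder mean \(\mu_j\) and log-variance \(\log\sigma_j^2\) for latent dimension \(j\):
\[
D_{KL} = -\frac{1}{2} \sum_{j} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\]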
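The closed-form Gaussian KL term and its weighting by \(\beta\) can be checked numerically. This is a minimal plain-Python sketch assuming a standard normal prior and a sum over latent dimensions; `gaussian_kl` and `beta_elbo_loss` are illustrative names, and the library's own reduction may differ.

```python
import math


def gaussian_kl(mean, logvar):
    """KL( N(mean, exp(logvar)) || N(0, 1) ), summed over latent dimensions."""
    return -0.5 * sum(
        1.0 + lv - m**2 - math.exp(lv) for m, lv in zip(mean, logvar)
    )


def beta_elbo_loss(recon_loss, kl, beta=1.0):
    """Reconstruction term plus beta-weighted KL term."""
    return recon_loss + beta * kl


kl_at_prior = gaussian_kl([0.0, 0.0], [0.0, 0.0])  # posterior equals the prior
kl_shifted = gaussian_kl([1.0, 0.0], [0.0, 0.0])   # one mean moved off zero
total_loss = beta_elbo_loss(0.25, kl_shifted, beta=4.0)
print(kl_at_prior, kl_shifted, total_loss)
```

The KL term is zero only when the posterior matches the prior, so a larger \(\beta\) penalizes any departure from it more heavily.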
Training Example¤
```python
def vae_training_loss(vae, data):
    """Train VAE normalizer."""
    result, _, _ = vae.apply({"counts": data}, {}, None)
    return vae_loss(
        data,
        result["reconstructed"],
        result["mu"],
        result["log_var"],
    )
```
Beta-VAE Scheduling¤
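Ramping the KL weight from 0 to 1 over the first epochs (KL warmup) lets the model learn to reconstruct before the latent space is fully regularized. For stronger disentanglement, scale the final weight above 1.
```python
# Warmup KL weight
def get_kl_weight(epoch, warmup_epochs=100):
    """Linearly increase the KL weight from 0 to 1 over warmup_epochs."""
    return min(1.0, epoch / warmup_epochs)


# Rebuild the loss with the weight for the current epoch
vae_loss = VAELoss(kl_weight=get_kl_weight(current_epoch))
```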
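The same linear warmup can target a final weight other than 1. The sketch below is one way to do that; `kl_weight_schedule` and its `beta_max` argument are hypothetical names and not part of DiffBio.

```python
def kl_weight_schedule(
    epoch: int, warmup_epochs: int = 100, beta_max: float = 1.0
) -> float:
    """Linearly ramp the KL weight from 0 to beta_max over warmup_epochs."""
    if warmup_epochs <= 0:
        return beta_max  # no warmup requested
    return beta_max * min(1.0, epoch / warmup_epochs)


for epoch in (0, 50, 100, 200):
    print(epoch, kl_weight_schedule(epoch, warmup_epochs=100, beta_max=4.0))
```

The returned value would be passed as `kl_weight` when constructing the loss for that epoch.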
HMMLikelihoodLoss¤
Negative log-likelihood loss for Hidden Markov Models.
Overview¤
HMM likelihood loss uses the forward algorithm to compute sequence log-likelihood, enabling end-to-end HMM training.
Usage¤
```python
from flax import nnx
from diffbio.losses.statistical_losses import HMMLikelihoodLoss

# Create loss function (HMM parameters are learnable members)
hmm_loss = HMMLikelihoodLoss(n_states=3, n_emissions=4, rngs=nnx.Rngs(42))

# Compute loss (signature: observations)
loss = hmm_loss(observations)  # (batch, seq_len) integer-encoded sequences
```
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| `n_states` | int | required | Number of hidden states |
| `n_emissions` | int | required | Number of emission symbols |
The loss owns learnable initial-state, transition, and emission log-probabilities; the forward algorithm runs internally on the supplied observations.
Algorithm¤
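Returns the mean negative log-likelihood over the batch:
\[
L = -\frac{1}{B} \sum_{b=1}^{B} \log P(O_b \mid \theta)
\]
Where \(B\) is the batch size and \(P(O_b \mid \theta)\), the likelihood of sequence \(O_b\) under the HMM parameters \(\theta\), is computed via the forward algorithm with logsumexp for stability.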
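To illustrate the forward algorithm in log space, the sketch below computes sequence log-likelihoods for a toy two-state HMM in plain Python and checks them against a brute-force sum over all hidden paths. All function names and the toy probabilities are made up for illustration; this is not DiffBio's implementation, which operates on batched arrays.

```python
import itertools
import math


def logsumexp(values):
    """Stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_log_likelihood(obs, log_init, log_trans, log_emit):
    """log P(obs | theta) via the forward algorithm in log space."""
    n_states = len(log_init)
    # log_alpha[s] = log P(o_1..o_t, state_t = s)
    log_alpha = [log_init[s] + log_emit[s][obs[0]] for s in range(n_states)]
    for o in obs[1:]:
        log_alpha = [
            logsumexp([log_alpha[p] + log_trans[p][s] for p in range(n_states)])
            + log_emit[s][o]
            for s in range(n_states)
        ]
    return logsumexp(log_alpha)


def brute_force_log_likelihood(obs, log_init, log_trans, log_emit):
    """Same quantity by summing over every hidden path (tiny inputs only)."""
    n_states = len(log_init)
    path_scores = []
    for path in itertools.product(range(n_states), repeat=len(obs)):
        score = log_init[path[0]] + log_emit[path[0]][obs[0]]
        for t in range(1, len(obs)):
            score += log_trans[path[t - 1]][path[t]] + log_emit[path[t]][obs[t]]
        path_scores.append(score)
    return logsumexp(path_scores)


def to_log(rows):
    """Convert a probability matrix to log-probabilities."""
    return [[math.log(p) for p in row] for row in rows]


log_init = [math.log(0.6), math.log(0.4)]
log_trans = to_log([[0.7, 0.3], [0.4, 0.6]])
log_emit = to_log([[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]])

batch = [[0, 1, 2], [2, 2, 0]]
log_liks = [forward_log_likelihood(o, log_init, log_trans, log_emit) for o in batch]
loss = -sum(log_liks) / len(batch)  # mean negative log-likelihood over the batch
print(loss)
```

The forward recursion costs time linear in sequence length, while the brute-force sum grows exponentially and is only useful as a check.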
Training Example¤
```python
def hmm_training_loss(hmm_loss_fn, observations):
    """Train HMM model with learnable parameters held inside the loss."""
    return hmm_loss_fn(observations)


# Train
grads = nnx.grad(hmm_training_loss)(hmm_loss, train_sequences)
```
Combining Statistical Losses¤
VAE with NB Reconstruction¤
For scRNA-seq VAE (scVI-style):
```python
import jax.numpy as jnp


def scvi_loss(vae, counts):
    """scVI-style loss with NB reconstruction."""
    result, _, _ = vae.apply({"counts": counts}, {}, None)

    # NB reconstruction loss (signature: counts, mu, theta)
    recon_loss = nb_loss(
        counts,
        result["reconstructed"],
        result["dispersion"],
    )

    # KL divergence to a standard normal prior
    kl_loss = -0.5 * jnp.mean(
        1 + result["log_var"] - result["mu"] ** 2 - jnp.exp(result["log_var"])
    )

    return recon_loss + kl_loss
```
HMM with Emission Learning¤
```python
def hmm_emission_loss(hmm_loss_fn, observations):
    """HMM with learnable initial/transition/emission log-probabilities."""
    return hmm_loss_fn(observations)
```
Numerical Stability¤
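All statistical losses use numerically stable, log-space implementations. For example, the negative binomial log-probability relies on `gammaln` rather than explicit factorials:
```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


# Log-space computation
def stable_nb_logprob(y, mu, alpha):
    """Negative binomial log-probability with mean mu and dispersion alpha."""
    # gammaln returns log-gamma directly, avoiding overflow in factorial terms
    log_prob = (
        gammaln(y + 1 / alpha)
        - gammaln(y + 1)
        - gammaln(1 / alpha)
        + (1 / alpha) * jnp.log(1 / (1 + alpha * mu))
        # Small epsilon keeps the log finite when mu is 0
        + y * jnp.log(alpha * mu / (1 + alpha * mu) + 1e-10)
    )
    return log_prob
```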
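To see why the log-space form matters, the sketch below contrasts it with a direct evaluation using gamma functions, in plain Python. For small counts the two agree, but for large counts `math.gamma` overflows while the `lgamma`-based version stays finite. Both function names are illustrative, the `log1p` rewrite is a variation on the formula rather than DiffBio's code, and \(\mu > 0\) is assumed.

```python
import math


def naive_nb_prob(y, mu, alpha):
    """Direct evaluation with gamma functions; overflows for large counts."""
    r = 1.0 / alpha
    p = alpha * mu / (1.0 + alpha * mu)
    coeff = math.gamma(y + r) / (math.gamma(y + 1) * math.gamma(r))
    return coeff * (1.0 - p) ** r * p**y


def log_space_nb_logprob(y, mu, alpha):
    """Same quantity in log space using lgamma and log1p (requires mu > 0)."""
    r = 1.0 / alpha
    log1p_am = math.log1p(alpha * mu)  # log(1 + alpha * mu)
    return (
        math.lgamma(y + r)
        - math.lgamma(y + 1)
        - math.lgamma(r)
        - r * log1p_am                           # r * log(1 / (1 + alpha*mu))
        + y * (math.log(alpha * mu) - log1p_am)  # y * log(alpha*mu / (1 + alpha*mu))
    )


# Small count: both routes agree.
small_naive = math.log(naive_nb_prob(5, 4.0, 0.5))
small_stable = log_space_nb_logprob(5, 4.0, 0.5)

# Large count: math.gamma overflows, the log-space version stays finite.
try:
    naive_nb_prob(500, 400.0, 0.5)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True
large_stable = log_space_nb_logprob(500, 400.0, 0.5)
print(small_naive, small_stable, naive_overflowed, large_stable)
```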
Next Steps¤
- See Statistical Operators for NB-GLM and HMM
- Explore Single-Cell Losses for batch correction
- Check Differential Expression Pipeline