# Statistical Losses API
Loss functions for statistical modeling in bioinformatics.
## NegativeBinomialLoss

`diffbio.losses.statistical_losses.NegativeBinomialLoss`
Bases: Module
Negative binomial log-likelihood loss for count data.
The negative binomial distribution is parameterized by mean (mu) and dispersion (theta), suitable for overdispersed count data like RNA-seq.
$$
\mathrm{NB}(x \mid \mu, \theta) = \frac{\Gamma(x + \theta)}{\Gamma(\theta)\,\Gamma(x + 1)} \left(\frac{\theta}{\theta + \mu}\right)^{\theta} \left(\frac{\mu}{\theta + \mu}\right)^{x}
$$
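The first two moments of this parameterization show where the overdispersion comes from. The variance exceeds the mean by $\mu^2/\theta$, so a smaller $\theta$ means stronger overdispersion. As $\theta \to \infty$ the distribution approaches the Poisson case, where the variance equals the mean:

$$
\mathbb{E}[x] = \mu, \qquad \operatorname{Var}[x] = \mu + \frac{\mu^{2}}{\theta}
$$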
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `eps` | `float` | Small constant for numerical stability. | `1e-08` |
| `rngs` | `Rngs \| None` | Flax NNX random number generators (not used; accepted for API consistency). | `None` |
### `__call__`
```python
__call__(
    counts: Float[Array, "batch genes"],
    mu: Float[Array, "batch genes"],
    theta: Float[Array, "genes"],
) -> Float[Array, ""]
```
Compute negative binomial negative log-likelihood.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `counts` | `Float[Array, 'batch genes']` | Observed counts. | required |
| `mu` | `Float[Array, 'batch genes']` | Predicted mean. | required |
| `theta` | `Float[Array, 'genes']` | Dispersion parameter (per gene). | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Mean negative log-likelihood (scalar). |
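The following is a minimal JAX sketch of what `__call__` computes. It is written from the formula above and is not taken from the diffbio source. The helpers `nb_log_prob` and `nb_nll_loss` are hypothetical. They assume that `eps` is added inside each logarithm and that the mean runs over every (sample, gene) entry. The library may reduce differently, for example by summing over genes before averaging.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def nb_log_prob(counts, mu, theta, eps=1e-8):
    """Elementwise log NB(counts | mu, theta) (hypothetical helper).

    counts, mu: (batch, genes); theta: (genes,), broadcast across the batch axis.
    eps is added inside each log as an assumed stabilizer.
    """
    log_theta_mu = jnp.log(theta + mu + eps)
    return (
        gammaln(counts + theta)
        - gammaln(theta)
        - gammaln(counts + 1.0)
        + theta * (jnp.log(theta + eps) - log_theta_mu)
        + counts * (jnp.log(mu + eps) - log_theta_mu)
    )


def nb_nll_loss(counts, mu, theta, eps=1e-8):
    """Mean negative log-likelihood over all (batch, gene) entries -> scalar."""
    return -jnp.mean(nb_log_prob(counts, mu, theta, eps))


# Usage: the scalar loss can be differentiated with jax.grad. For fixed theta,
# d/dmu log NB = counts/mu - (counts + theta)/(theta + mu), which is zero at mu == counts.
counts = jnp.array([[2.0, 5.0], [1.0, 8.0]])
theta = jnp.array([3.0, 10.0])
loss = nb_nll_loss(counts, counts, theta)
grad_mu = jax.grad(nb_nll_loss, argnums=1)(counts, counts, theta)
```

Working with `gammaln` rather than `Gamma` keeps large counts from overflowing. Because `theta` has shape `(genes,)`, it broadcasts across the batch axis with no explicit tiling.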
## VAELoss

`diffbio.losses.statistical_losses.VAELoss`
Bases: Module
Variational autoencoder ELBO loss.
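Combines a reconstruction term with KL-divergence regularization:

$$
\mathrm{ELBO} = \mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z)\right] - \mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)
$$

For a diagonal Gaussian encoder $q(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and a standard normal prior $p(z) = \mathcal{N}(0, I)$, the KL term has the closed form:

$$
\mathrm{KL} = -\frac{1}{2}\sum_{j}\left(1 + \log\sigma_j^{2} - \mu_j^{2} - \sigma_j^{2}\right)
$$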
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `kl_weight` | `float` | Weight for the KL divergence term (beta-VAE). | `1.0` |
| `reconstruction_type` | `str` | Type of reconstruction loss: `"mse"` or `"bce"`. | `'mse'` |
| `rngs` | `Rngs \| None` | Flax NNX random number generators (not used; accepted for API consistency). | `None` |
### `__call__`
```python
__call__(
    x: Float[Array, "batch features"],
    x_recon: Float[Array, "batch features"],
    mean: Float[Array, "batch latent"],
    logvar: Float[Array, "batch latent"],
) -> Float[Array, ""]
```
Compute the VAE loss (negative ELBO).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, 'batch features']` | Original input. | required |
| `x_recon` | `Float[Array, 'batch features']` | Reconstructed input. | required |
| `mean` | `Float[Array, 'batch latent']` | Encoder mean. | required |
| `logvar` | `Float[Array, 'batch latent']` | Encoder log-variance. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Negative ELBO (scalar), with the KL term weighted by `kl_weight`. |
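The following is a minimal JAX sketch of the negative-ELBO computation documented above. It is not the diffbio implementation. The helpers `gaussian_kl` and `neg_elbo` are hypothetical and rest on three assumptions: both terms are summed over feature and latent dimensions and then averaged over the batch, `x_recon` holds probabilities in (0, 1) rather than logits for `"bce"`, and `eps` is a clipping constant chosen for illustration.

```python
import jax.numpy as jnp


def gaussian_kl(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)) summed over latent dims -> (batch,)."""
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)


def neg_elbo(x, x_recon, mean, logvar, kl_weight=1.0, reconstruction_type="mse", eps=1e-7):
    """Batch mean of reconstruction loss + kl_weight * KL (hypothetical helper)."""
    if reconstruction_type == "mse":
        # Squared error summed over features (Gaussian decoder up to constants)
        recon = jnp.sum((x - x_recon) ** 2, axis=-1)
    elif reconstruction_type == "bce":
        # Assumes x_recon holds probabilities; clip away from 0 and 1 before the log
        p = jnp.clip(x_recon, eps, 1.0 - eps)
        recon = -jnp.sum(x * jnp.log(p) + (1.0 - x) * jnp.log1p(-p), axis=-1)
    else:
        raise ValueError(f"unknown reconstruction_type: {reconstruction_type!r}")
    return jnp.mean(recon + kl_weight * gaussian_kl(mean, logvar))
```

Setting `kl_weight` above 1 gives the beta-VAE trade-off. It pulls the posterior harder toward the prior, at some cost to reconstruction quality.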
## HMMLikelihoodLoss

`diffbio.losses.statistical_losses.HMMLikelihoodLoss`
Bases: Module
HMM negative log-likelihood loss.
Computes the negative log-likelihood of sequences under a Hidden Markov Model using the forward algorithm with logsumexp for stability.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_states` | `int` | Number of hidden states. | required |
| `n_emissions` | `int` | Number of emission symbols. | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. | `None` |
### `__call__`
Compute the mean negative log-likelihood over the batch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `observations` | `Int[Array, 'batch seq_len']` | Batch of integer-encoded sequences. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Mean negative log-likelihood (scalar). |
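The following is a minimal JAX sketch of the forward algorithm behind this loss. It uses `jax.lax.scan` over time steps and `jax.vmap` over the batch. The diffbio module holds its own HMM parameters. Here they are passed in explicitly as log-probabilities, and the helper names `hmm_log_likelihood` and `hmm_nll_loss` are hypothetical. Deriving normalized log-probabilities from unconstrained logits with `jax.nn.log_softmax` is one common choice. This is an assumption, not a description of how `HMMLikelihoodLoss` stores its parameters.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hmm_log_likelihood(obs, log_initial, log_transition, log_emission):
    """log p(obs) for one integer sequence of shape (seq_len,) via the forward algorithm.

    log_initial: (n_states,); log_transition: (n_states, n_states), row i = from-state i;
    log_emission: (n_states, n_emissions).
    """
    # alpha_1(j) = log pi_j + log B[j, o_1]
    log_alpha = log_initial + log_emission[:, obs[0]]

    def step(log_alpha, o_t):
        # alpha_t(j) = logsumexp_i(alpha_{t-1}(i) + log A[i, j]) + log B[j, o_t]
        log_alpha = logsumexp(log_alpha[:, None] + log_transition, axis=0) + log_emission[:, o_t]
        return log_alpha, None

    log_alpha, _ = jax.lax.scan(step, log_alpha, obs[1:])
    return logsumexp(log_alpha)


def hmm_nll_loss(observations, log_initial, log_transition, log_emission):
    """Mean negative log-likelihood over a (batch, seq_len) array of sequences."""
    log_liks = jax.vmap(hmm_log_likelihood, in_axes=(0, None, None, None))(
        observations, log_initial, log_transition, log_emission
    )
    return -jnp.mean(log_liks)


# Usage with random but properly normalized parameters (3 states, 4 symbols)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
log_initial = jax.nn.log_softmax(jax.random.normal(k1, (3,)))
log_transition = jax.nn.log_softmax(jax.random.normal(k2, (3, 3)), axis=-1)
log_emission = jax.nn.log_softmax(jax.random.normal(k3, (3, 4)), axis=-1)
observations = jnp.array([[0, 1, 2, 3, 2], [3, 3, 1, 0, 0]])
loss = hmm_nll_loss(observations, log_initial, log_transition, log_emission)
```

The forward pass costs O(seq_len × n_states²). Enumerating every hidden-state path would cost O(n_states^seq_len).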
## Usage Examples
### Negative Binomial Loss
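```python
import jax.numpy as jnp
from diffbio.losses import NegativeBinomialLoss

nb_loss = NegativeBinomialLoss()

# Toy count data (RNA-seq, scRNA-seq): 2 samples x 3 genes
observed_counts = jnp.array([[0.0, 3.0, 10.0], [1.0, 0.0, 7.0]])
model_means = jnp.array([[0.5, 2.5, 9.0], [1.2, 0.3, 6.5]])
dispersions = jnp.array([2.0, 5.0, 10.0])

loss = nb_loss(
    counts=observed_counts,  # (n_samples, n_genes)
    mu=model_means,          # (n_samples, n_genes), positive
    theta=dispersions,       # (n_genes,), positive
)
```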
### VAE Loss (ELBO)
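```python
import jax.numpy as jnp
from diffbio.losses import VAELoss

vae_loss = VAELoss(kl_weight=1.0)

# Toy batch: 4 samples, 8 features, 2 latent dimensions
input_data = jnp.ones((4, 8))
decoded = jnp.full((4, 8), 0.9)
latent_mean = jnp.zeros((4, 2))
latent_logvar = jnp.zeros((4, 2))

# Negative evidence lower bound (scalar)
loss = vae_loss(
    x=input_data,
    x_recon=decoded,
    mean=latent_mean,
    logvar=latent_logvar,
)
```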
### HMM Likelihood Loss
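```python
import jax.numpy as jnp
from flax import nnx
from diffbio.losses import HMMLikelihoodLoss

# n_states and n_emissions are required; __call__ takes only the observations
hmm_loss = HMMLikelihoodLoss(n_states=3, n_emissions=4, rngs=nnx.Rngs(0))

# Integer-encoded sequences, shape (batch, seq_len), symbols in [0, n_emissions)
observed_sequences = jnp.array([[0, 1, 2, 3, 2], [3, 3, 1, 0, 0]])

# Mean negative log-likelihood for HMM training
loss = hmm_loss(observations=observed_sequences)
```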
## Mathematical Details
### Negative Binomial
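In log space the likelihood splits into log-gamma terms, which avoids overflowing $\Gamma$ for large counts. Assuming the mean is taken over all $B \times G$ (sample, gene) entries, the loss is the negative average of this quantity, with the per-gene dispersion $\theta_g$ shared across samples:

$$
\log \mathrm{NB}(x \mid \mu, \theta) = \log\Gamma(x + \theta) - \log\Gamma(\theta) - \log\Gamma(x + 1) + \theta \log\frac{\theta}{\theta + \mu} + x \log\frac{\mu}{\theta + \mu}
$$

$$
\mathcal{L}_{\mathrm{NB}} = -\frac{1}{BG}\sum_{b=1}^{B}\sum_{g=1}^{G} \log \mathrm{NB}\left(x_{bg} \mid \mu_{bg}, \theta_{g}\right)
$$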
### VAE ELBO
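Let $\beta$ denote `kl_weight`, and let the encoder output $\ell_j = \log\sigma_j^2$ (`logvar`), so that $\sigma_j^2 = e^{\ell_j}$. The quantity minimized is then the $\beta$-weighted negative ELBO, which reduces to $-\mathrm{ELBO}$ when $\beta = 1$:

$$
\mathcal{L}_{\mathrm{VAE}} = -\mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z)\right] + \beta\,\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big), \qquad \mathrm{KL} = -\frac{1}{2}\sum_{j}\left(1 + \ell_j - \mu_j^{2} - e^{\ell_j}\right)
$$

The `"mse"` reconstruction corresponds to a Gaussian decoder likelihood up to an additive constant and a scale factor. The `"bce"` reconstruction corresponds to a Bernoulli decoder.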
### HMM Forward Algorithm
The forward recursion is computed in log space, using logsumexp for numerical stability.
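The notation is as follows. $\pi$ holds the initial-state probabilities. $A$ is the transition matrix, with $A_{ij} = P(\text{state } j \text{ at } t \mid \text{state } i \text{ at } t-1)$. $B$ is the emission matrix, with $B_{jk} = P(\text{symbol } k \mid \text{state } j)$. For an observed sequence $o_1, \dots, o_T$, the recursion in log space is:

$$
\log\alpha_1(j) = \log\pi_j + \log B_{j,o_1}
$$

$$
\log\alpha_t(j) = \log B_{j,o_t} + \operatorname{logsumexp}_{i}\left(\log\alpha_{t-1}(i) + \log A_{ij}\right), \qquad t = 2, \dots, T
$$

$$
\log p(o_{1:T}) = \operatorname{logsumexp}_{j}\,\log\alpha_T(j), \qquad \operatorname{logsumexp}_i(a_i) = m + \log\sum_i e^{a_i - m}, \quad m = \max_i a_i
$$

Subtracting the maximum $m$ keeps every exponent at or below zero, so the sum cannot overflow. At least one term equals 1, so the sum cannot underflow to zero either. The loss is the batch mean of $-\log p(o_{1:T})$.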