Loss Functions Overview

DiffBio provides specialized loss functions for training differentiable bioinformatics pipelines.

Available Loss Functions

Single-Cell Losses

| Loss | Description |
|---|---|
| BatchMixingLoss | Maximizes batch mixing in latent space |
| ClusteringCompactnessLoss | Encourages tight, well-separated clusters |
| VelocityConsistencyLoss | Ensures RNA velocity consistency |
| ShannonDiversityLoss | Shannon entropy of cluster assignments |
| SimpsonDiversityLoss | Simpson concentration index of assignments |
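The two diversity losses each reduce soft cluster assignments to a single number. The sketch below assumes that each loss receives an (n_cells, n_clusters) matrix of soft assignments and first averages it into cluster proportions p_k. Under that assumption the quantities are the Shannon entropy H = -Σ p_k log p_k and the Simpson concentration D = Σ p_k². The helper functions are hypothetical illustrations, not DiffBio's implementation. The page does not say whether the library returns these values directly, negated, or normalized.

```python
import jax
import jax.numpy as jnp


def cluster_proportions(assignments: jax.Array) -> jax.Array:
    """Average soft assignments of shape (n_cells, n_clusters) into proportions p_k."""
    return jnp.mean(assignments, axis=0)


def shannon_diversity(assignments: jax.Array, eps: float = 1e-12) -> jax.Array:
    """Shannon entropy H = -sum_k p_k log p_k (maximal, log k, when clusters are even)."""
    p = cluster_proportions(assignments)
    # eps keeps the log finite (and its gradient defined) for empty clusters
    return -jnp.sum(p * jnp.log(p + eps))


def simpson_concentration(assignments: jax.Array) -> jax.Array:
    """Simpson concentration D = sum_k p_k^2 (minimal, 1/k, when clusters are even)."""
    p = cluster_proportions(assignments)
    return jnp.sum(p**2)


# Soft assignments computed from logits keep both quantities differentiable.
logits = jax.random.normal(jax.random.PRNGKey(0), (100, 4))
assignments = jax.nn.softmax(logits, axis=-1)
print(shannon_diversity(assignments), simpson_concentration(assignments))

# Gradient of the entropy with respect to the logits, shape (100, 4)
grad_logits = jax.grad(lambda z: shannon_diversity(jax.nn.softmax(z, axis=-1)))(logits)
```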

Metric Losses

| Loss | Description |
|---|---|
| DifferentiableAUROC | Sigmoid-approximated AUROC for training |
| ExactAUROC | Exact trapezoidal-rule AUROC for evaluation |
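The training-time AUROC replaces the hard pairwise indicator 1[s_pos > s_neg] with a sigmoid, so gradients can reach the scores. The sketch below assumes binary 0/1 labels and a temperature parameter; both are assumptions, and the function names are hypothetical. For the exact version it uses the pairwise (Mann-Whitney) form instead of an explicit trapezoid. With ties counted as one half, the two forms give the same area under the empirical ROC curve.

```python
import jax
import jax.numpy as jnp


def soft_auroc(scores, labels, temperature=0.1):
    """Sigmoid relaxation: 1[s_pos > s_neg] becomes sigmoid((s_pos - s_neg) / T)."""
    pos, neg = labels == 1, labels == 0
    diff = scores[:, None] - scores[None, :]  # diff[i, j] = s_i - s_j
    pair_mask = pos[:, None] & neg[None, :]   # i positive, j negative
    soft_wins = jax.nn.sigmoid(diff / temperature)
    return jnp.sum(soft_wins * pair_mask) / jnp.sum(pair_mask)


def exact_auroc(scores, labels):
    """Exact AUROC via pairwise comparison (ties count 1/2), for evaluation only."""
    pos, neg = labels == 1, labels == 0
    diff = scores[:, None] - scores[None, :]
    pair_mask = pos[:, None] & neg[None, :]
    wins = jnp.where(diff > 0, 1.0, jnp.where(diff == 0, 0.5, 0.0))
    return jnp.sum(wins * pair_mask) / jnp.sum(pair_mask)


scores = jnp.array([0.9, 0.4, 0.5, 0.2])
labels = jnp.array([1, 1, 0, 0])
print(exact_auroc(scores, labels))  # 0.75: 3 of 4 positive/negative pairs ordered correctly
print(soft_auroc(scores, labels))

# The relaxed version has useful gradients with respect to the scores.
grad_scores = jax.grad(soft_auroc)(scores, labels)
```

A lower temperature makes the relaxation track the exact value more closely, but its gradients become sharper and sparser.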

Statistical Losses

| Loss | Description |
|---|---|
| NegativeBinomialLoss | NB log-likelihood for count data |
| VAELoss | ELBO loss with KL regularization |
| HMMLikelihoodLoss | HMM forward algorithm loss |
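The negative binomial term scores observed counts against a predicted mean mu and a dispersion theta. The sketch below assumes the mean/inverse-dispersion parameterization common in single-cell models, log NB(x | μ, θ) = log Γ(x+θ) − log Γ(θ) − log Γ(x+1) + θ log(θ/(θ+μ)) + x log(μ/(θ+μ)), and a mean reduction. NegativeBinomialLoss may differ in parameterization, reduction, or stabilization. The function names are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def nb_log_likelihood(counts, mu, theta, eps=1e-8):
    """Elementwise log NB(counts | mean=mu, inverse dispersion=theta)."""
    log_theta_mu = jnp.log(theta + mu + eps)
    return (
        gammaln(counts + theta)
        - gammaln(theta)
        - gammaln(counts + 1.0)
        + theta * (jnp.log(theta + eps) - log_theta_mu)
        + counts * (jnp.log(mu + eps) - log_theta_mu)
    )


def nb_loss(counts, mu, theta):
    """Negative mean log-likelihood: the quantity a training loop would minimize."""
    return -jnp.mean(nb_log_likelihood(counts, mu, theta))


counts = jnp.array([[0.0, 3.0, 7.0], [1.0, 0.0, 12.0]])
mu = jnp.array([[0.5, 2.5, 8.0], [1.2, 0.3, 10.0]])
theta = jnp.array([2.0, 5.0, 1.0])  # one dispersion per gene, broadcast over cells
print(nb_loss(counts, mu, theta))
```

Working in log space with gammaln avoids overflow of the factorial and gamma terms for large counts.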

Alignment Losses

| Loss | Description |
|---|---|
| AlignmentScoreLoss | Alignment quality loss |
| AlignmentConsistencyLoss | Penalizes inconsistent soft alignment paths |
| SoftEditDistanceLoss | Differentiable approximation of edit distance |
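One common way to make edit distance differentiable has two parts. First, run the Levenshtein dynamic program with the hard min replaced by a smooth soft-min, −γ log Σ exp(−v/γ). Second, compare sequences as one-hot (or soft probability) profiles. The sketch below shows that general technique under those assumptions. SoftEditDistanceLoss may use a different relaxation, and `soft_edit_distance`, `softmin`, and `one_hot` are hypothetical names. Because the soft-min never exceeds the min, the value approaches the exact edit distance from below as γ → 0.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def softmin(values, gamma):
    """Smooth minimum: -gamma * log(sum(exp(-v / gamma)))."""
    return -gamma * logsumexp(-jnp.stack(values) / gamma)


def one_hot(seq, alphabet="ACGT"):
    """Encode a string as an (len(seq), len(alphabet)) one-hot profile."""
    idx = jnp.array([alphabet.index(c) for c in seq])
    return jax.nn.one_hot(idx, len(alphabet))


def soft_edit_distance(a, b, gamma=0.1):
    """Levenshtein DP with min replaced by softmin.

    a: (n, alphabet) and b: (m, alphabet) one-hot or soft profiles.
    """
    n, m = a.shape[0], b.shape[0]
    sub_cost = 1.0 - a @ b.T  # 0 for matching one-hot symbols, 1 otherwise
    # prev[j] holds D[i-1, j]; row 0 costs j insertions.
    prev = [jnp.asarray(float(j)) for j in range(m + 1)]
    for i in range(1, n + 1):
        curr = [jnp.asarray(float(i))]  # column 0 costs i deletions
        for j in range(1, m + 1):
            curr.append(
                softmin(
                    [
                        prev[j] + 1.0,                         # deletion
                        curr[j - 1] + 1.0,                     # insertion
                        prev[j - 1] + sub_cost[i - 1, j - 1],  # substitution or match
                    ],
                    gamma,
                )
            )
        prev = curr
    return prev[m]


print(soft_edit_distance(one_hot("GATTACA"), one_hot("GATACA"), gamma=0.05))
# Gradients flow back to the input profile, shape (4, 4)
grad_a = jax.grad(soft_edit_distance)(one_hot("ACGT"), one_hot("AGT"))
```

The Python loops keep the recurrence readable. For long sequences, a vectorized anti-diagonal or `jax.lax.scan` formulation would compile much better.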

Using Loss Functions

Basic Usage

```python
from diffbio.losses import NegativeBinomialLoss

# Create the loss function
nb_loss = NegativeBinomialLoss()

# Compute the loss; arguments are positional: (counts, mu, theta)
loss = nb_loss(
    observed_counts,    # counts: observed count data
    model_predictions,  # mu: predicted NB means
    dispersions,        # theta: NB dispersion parameters
)
```

With Training

```python
from flax import nnx
from diffbio.losses import BatchMixingLoss

batch_loss = BatchMixingLoss(n_neighbors=15, temperature=1.0)


def train_step(model, data):
    def loss_fn(m):
        result, _, _ = m.apply(data, {}, None)
        return batch_loss(result["embeddings"], data["batch_ids"])

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    return loss, grads
```

Combining Multiple Losses

```python
from diffbio.losses import (
    BatchMixingLoss,
    ClusteringCompactnessLoss,
)

batch_loss = BatchMixingLoss()
cluster_loss = ClusteringCompactnessLoss()


def combined_loss(model, data):
    result, _, _ = model.apply(data, {}, None)

    # Batch mixing: minimized directly, as in the training step above
    l_batch = batch_loss(result["embeddings"], data["batch_ids"])

    # Clustering compactness (minimize)
    l_cluster = cluster_loss(result["embeddings"], result["assignments"])

    # Weighted combination (0.5 is an example weight for the compactness term)
    return l_batch + 0.5 * l_cluster
```
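The combined objective calls ClusteringCompactnessLoss with embeddings and soft assignments, but this page does not give its formula. The sketch below shows one common form of the "tight clusters" half of such a term: an assignment-weighted within-cluster squared distance to soft centroids. The function `soft_compactness` is hypothetical, and the separation part of "well-separated" is deliberately left out. The library's actual loss may be defined differently.

```python
import jax
import jax.numpy as jnp


def soft_compactness(embeddings, assignments, eps=1e-8):
    """Assignment-weighted within-cluster squared distance.

    embeddings: (n_cells, d); assignments: (n_cells, k), rows summing to 1.
    """
    weights = assignments.sum(axis=0)                                    # (k,)
    centroids = (assignments.T @ embeddings) / (weights[:, None] + eps)  # (k, d)
    sq_dist = jnp.sum(
        (embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1
    )                                                                    # (n, k)
    return jnp.sum(assignments * sq_dist) / embeddings.shape[0]


key_x, key_a = jax.random.split(jax.random.PRNGKey(0))
embeddings = jax.random.normal(key_x, (50, 8))
assignments = jax.nn.softmax(jax.random.normal(key_a, (50, 3)), axis=-1)
print(soft_compactness(embeddings, assignments))
grad_emb = jax.grad(soft_compactness)(embeddings, assignments)  # shape (50, 8)
```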

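Loss Function Interface

All DiffBio losses follow a consistent interface:

```python
import jax


class Loss:
    """Schematic interface shared by DiffBio losses."""

    def __init__(self, **config):
        """Initialize the loss with its configuration (e.g. n_neighbors, temperature)."""

    def __call__(self, *inputs: jax.Array) -> jax.Array:
        """Compute the loss value.

        Args:
            *inputs: Input arrays, passed positionally as in the examples above.

        Returns:
            Scalar loss value.
        """
        raise NotImplementedError
```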
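The contract here has three parts: configuration goes into the constructor, arrays come in through the call, and a scalar comes out. A user-defined loss can follow the same shape and still work with JAX transforms. The class below is a hypothetical example written for illustration; it is not part of DiffBio. It passes inputs positionally, as the usage examples on this page do. The page does not say whether DiffBio's own losses are plain classes or Flax NNX modules.

```python
import jax
import jax.numpy as jnp


class TemperatureCrossEntropyLoss:
    """Hypothetical custom loss following the interface above (not part of DiffBio)."""

    def __init__(self, temperature: float = 1.0):
        # Configuration is fixed once, at construction time.
        self.temperature = temperature

    def __call__(self, logits: jax.Array, labels: jax.Array) -> jax.Array:
        # Arrays in, scalar out: mean negative log-probability of the true class.
        log_probs = jax.nn.log_softmax(logits / self.temperature, axis=-1)
        picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
        return -jnp.mean(picked)


loss = TemperatureCrossEntropyLoss(temperature=0.5)
logits = jnp.array([[2.0, 0.5], [0.1, 1.5]])
labels = jnp.array([0, 1])

value = loss(logits, labels)
# Differentiate and compile through small wrapper functions.
grads = jax.grad(lambda z: loss(z, labels))(logits)  # shape (2, 2)
fast = jax.jit(lambda z, y: loss(z, y))(logits, labels)
```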

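Gradient Properties

All losses are designed for:

  • Numerical stability: log-space computations where needed
  • Smooth gradients: temperature-controlled soft operations
  • JAX compatibility: full support for jax.grad and jax.jit; when differentiating through a Flax NNX model, use the NNX-aware transforms nnx.grad and nnx.jit

```python
from flax import nnx
from diffbio.losses import BatchMixingLoss

batch_loss = BatchMixingLoss()


def loss_fn(model, data):
    result, _, _ = model.apply(data, {}, None)
    return batch_loss(result["embeddings"], data["batch_ids"])


# Gradient computation (with respect to the model's nnx.Param state)
grads = nnx.grad(loss_fn)(model, data)


# JIT compilation
@nnx.jit
def compute_loss(model, data):
    return loss_fn(model, data)
```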
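The numerical-stability point matters most for likelihoods over long sequences, such as the HMM forward algorithm behind HMMLikelihoodLoss. Multiplying probabilities step by step underflows quickly, while adding log-probabilities and combining them with logsumexp does not. The sketch below is a minimal log-space forward recursion under `jax.lax.scan`. The function names and the argument layout (log initial, transition, and per-step emission probabilities) are assumptions for illustration, not the signature of HMMLikelihoodLoss.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hmm_log_likelihood(log_init, log_trans, log_emit):
    """Log-space forward algorithm, returning log P(observations).

    log_init:  (S,)   log initial state probabilities
    log_trans: (S, S) log_trans[i, j] = log P(next state j | state i)
    log_emit:  (T, S) log P(observation t | state s), precomputed per step
    """

    def step(log_alpha, log_e):
        # alpha_t(j) = e_t(j) * sum_i alpha_{t-1}(i) * A(i, j), with the sum
        # done by logsumexp so nothing underflows on long sequences.
        log_alpha = logsumexp(log_alpha[:, None] + log_trans, axis=0) + log_e
        return log_alpha, None

    log_alpha_0 = log_init + log_emit[0]
    log_alpha_T, _ = jax.lax.scan(step, log_alpha_0, log_emit[1:])
    return logsumexp(log_alpha_T)


def hmm_nll(log_init, log_trans, log_emit):
    """Negative log-likelihood, suitable for minimization."""
    return -hmm_log_likelihood(log_init, log_trans, log_emit)


# 2000 steps: a probability-space product would underflow to 0.0 in float32.
log_init = jnp.log(jnp.array([0.6, 0.4]))
log_trans = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
log_emit = jnp.log(
    jax.random.uniform(jax.random.PRNGKey(0), (2000, 2), minval=0.05, maxval=0.95)
)
print(jax.jit(hmm_nll)(log_init, log_trans, log_emit))
grad_emit = jax.grad(hmm_nll, argnums=2)(log_init, log_trans, log_emit)  # shape (2000, 2)
```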

Next Steps