Loss Functions Overview
DiffBio provides specialized loss functions for training differentiable bioinformatics pipelines.
Available Loss Functions
Losses are grouped into four categories: Single-Cell Losses, Metric Losses, Statistical Losses, and Alignment Losses.
Alignment Losses
| Loss | Description |
|------|-------------|
| AlignmentScoreLoss | Alignment quality loss |
| AlignmentConsistencyLoss | Penalises inconsistent soft alignment paths |
| SoftEditDistanceLoss | Differentiable approximation of edit distance |
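The table does not say how SoftEditDistanceLoss relaxes the edit distance. One common construction, sketched here as an assumption and not as DiffBio's implementation, replaces the hard `min` in the Levenshtein dynamic program with a temperature-controlled softmin and scores a substitution as `1 - <x_i, y_j>` between per-position symbol probabilities. The helper names `softmin`, `soft_edit_distance` and `one_hot` are hypothetical.

```python
import jax
import jax.numpy as jnp


def softmin(values: jax.Array, temperature: float) -> jax.Array:
    # Smooth minimum: -t * logsumexp(-v / t); approaches min(v) as t -> 0
    return -temperature * jax.nn.logsumexp(-values / temperature)


def soft_edit_distance(x: jax.Array, y: jax.Array, temperature: float = 1.0) -> jax.Array:
    """Levenshtein recurrence with the hard min replaced by softmin (hypothetical sketch).

    x: (n, alphabet_size) per-position symbol probabilities (or one-hot rows)
    y: (m, alphabet_size) likewise for the second sequence
    """
    n, m = x.shape[0], y.shape[0]
    # Substitution cost is 0 for identical one-hot symbols and 1 for different ones
    sub_cost = 1.0 - x @ y.T
    # Row 0 of the DP table: turning an empty prefix into y[:j] costs j insertions
    prev = jnp.arange(m + 1, dtype=x.dtype)
    for i in range(1, n + 1):
        # Column 0: deleting the first i symbols of x costs i
        row = [jnp.asarray(i, dtype=x.dtype)]
        for j in range(1, m + 1):
            candidates = jnp.stack([
                prev[j] + 1.0,                         # delete x[i-1]
                row[j - 1] + 1.0,                      # insert y[j-1]
                prev[j - 1] + sub_cost[i - 1, j - 1],  # match or substitute
            ])
            row.append(softmin(candidates, temperature))
        prev = jnp.stack(row)
    return prev[m]


def one_hot(symbols):
    # Hypothetical DNA encoding: A=0, C=1, G=2, T=3
    return jax.nn.one_hot(jnp.array(symbols), 4)


# ACGT vs AGT: the hard edit distance is 1 (delete C); the soft value is slightly below it
print(soft_edit_distance(one_hot([0, 1, 2, 3]), one_hot([0, 2, 3]), temperature=0.01))
```

The Python loops keep the recurrence readable but unroll O(nm) operations at trace time; a practical implementation would more likely vectorise over anti-diagonals or use `jax.lax.scan`.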
Using Loss Functions
Basic Usage
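```python
from diffbio.losses import NegativeBinomialLoss

# Create the loss function
nb_loss = NegativeBinomialLoss()

# Compute the loss; arguments are positional: (counts, mu, theta)
loss = nb_loss(
    observed_counts,    # observed counts
    model_predictions,  # predicted NB means (mu)
    dispersions,        # NB dispersion parameters (theta)
)
```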
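For reference, the standard negative binomial log-likelihood with mean `mu` and inverse dispersion `theta` (variance `mu + mu**2 / theta`, so larger `theta` means less overdispersion) is sketched below. Treating the `dispersions` argument as this inverse dispersion, reducing with a mean, and the `eps` guard are all assumptions; `nb_log_prob` and `nb_negative_log_likelihood` are hypothetical names, not DiffBio functions.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def nb_log_prob(counts, mu, theta, eps=1e-8):
    """Elementwise log P(counts) under NB with mean mu and inverse dispersion theta."""
    # Shared term log(theta + mu), computed once
    log_theta_mu = jnp.log(theta + mu + eps)
    return (
        gammaln(counts + theta)
        - gammaln(theta)
        - gammaln(counts + 1.0)
        + theta * (jnp.log(theta + eps) - log_theta_mu)
        + counts * (jnp.log(mu + eps) - log_theta_mu)
    )


def nb_negative_log_likelihood(counts, mu, theta):
    # Average over all entries so the loss is a scalar suitable for jax.grad
    return -jnp.mean(nb_log_prob(counts, mu, theta))
```

Working with `gammaln` and logarithms throughout keeps the computation in log space, which avoids the overflow that evaluating the binomial coefficient directly would cause for large counts.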
With Training
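```python
from flax import nnx
from diffbio.losses import BatchMixingLoss

batch_loss = BatchMixingLoss(n_neighbors=15, temperature=1.0)


def train_step(model, data):
    def loss_fn(m):
        # apply returns a 3-tuple; only the first element (a dict of outputs) is used here
        result, _, _ = m.apply(data, {}, None)
        return batch_loss(result["embeddings"], data["batch_ids"])

    # Differentiate with respect to the model's nnx.Param variables;
    # pass grads to an optimizer (e.g. nnx.Optimizer) to update the model
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    return loss, grads
```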
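BatchMixingLoss's internals are not described on this page. As an illustration only, one differentiable way to measure mixing is the entropy of batch labels within soft neighbourhoods: softmax weights over negative squared distances, scaled by a temperature, stand in for a hard k-nearest-neighbour graph. This sketch ignores `n_neighbors`, returns the entropy itself (higher means better mixed), and `soft_batch_entropy` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def soft_batch_entropy(embeddings, batch_ids, n_batches, temperature=1.0):
    """Mean entropy of batch labels over soft neighbourhoods (higher = better mixed).

    embeddings: (n_cells, dim); batch_ids: (n_cells,) integer batch labels.
    """
    n_cells = embeddings.shape[0]
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    sq_dists = jnp.sum(diff ** 2, axis=-1)  # (n_cells, n_cells)
    # Push each cell's distance to itself to a huge value so it gets ~0 weight
    sq_dists = sq_dists + 1e9 * jnp.eye(n_cells)
    # Soft neighbourhood: closer cells get exponentially more weight
    weights = jax.nn.softmax(-sq_dists / temperature, axis=1)
    # Per-cell distribution over the batches of its neighbours
    batch_probs = weights @ jax.nn.one_hot(batch_ids, n_batches)
    entropy = -jnp.sum(batch_probs * jnp.log(batch_probs + 1e-12), axis=1)
    return jnp.mean(entropy)
```

Building the full pairwise distance matrix costs O(n²) memory, which is one reason a real implementation might restrict attention to the `n_neighbors` closest cells.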
Combining Multiple Losses
```python
from diffbio.losses import (
    BatchMixingLoss,
    ClusteringCompactnessLoss,
)

batch_loss = BatchMixingLoss()
cluster_loss = ClusteringCompactnessLoss()


def combined_loss(model, data):
    result, _, _ = model.apply(data, {}, None)
    # Batch mixing (maximize, so negate)
    l_batch = -batch_loss(result["embeddings"], data["batch_ids"])
    # Clustering compactness (minimize)
    l_cluster = cluster_loss(result["embeddings"], result["assignments"])
    # Weighted combination
    return l_batch + 0.5 * l_cluster
```
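ClusteringCompactnessLoss is called with embeddings and cluster assignments. A standard objective with that signature, shown here as an assumption rather than DiffBio's definition, is the soft k-means distortion: each centroid is the assignment-weighted mean of the embeddings, and the loss is the assignment-weighted squared distance from each cell to each centroid, averaged over cells. `soft_compactness` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def soft_compactness(embeddings, assignments, eps=1e-8):
    """Soft k-means distortion (hypothetical sketch).

    embeddings: (n_cells, dim); assignments: (n_cells, n_clusters), rows sum to 1.
    """
    cluster_mass = assignments.sum(axis=0)  # (n_clusters,)
    # Centroid k = weighted mean of embeddings, weights = assignment column k
    centroids = (assignments.T @ embeddings) / (cluster_mass[:, None] + eps)
    diff = embeddings[:, None, :] - centroids[None, :, :]
    sq_dists = jnp.sum(diff ** 2, axis=-1)  # (n_cells, n_clusters)
    return jnp.sum(assignments * sq_dists) / embeddings.shape[0]
```

Because the centroids are recomputed from `assignments`, gradients flow into both the embeddings and the assignment probabilities, which is what lets such a term shape the clustering head as well as the embedding.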
Loss Function Interface
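All DiffBio losses follow a common interface: configuration is passed to the constructor, and calling the loss on input arrays returns a scalar:
```python
import jax


class Loss:
    def __init__(self, **config):
        """Initialize the loss with its configuration (e.g. a temperature)."""

    def __call__(self, *inputs) -> jax.Array:
        """Compute the loss value from positional input arrays.

        Returns:
            Scalar loss value.
        """
```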
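A user-defined loss can follow the same convention. The class below is a hypothetical example that is not part of DiffBio and does not subclass any DiffBio base class: configuration goes to `__init__`, arrays go to `__call__`, and the result is a scalar, so it composes with `jax.grad` and `jax.jit`.

```python
import jax
import jax.numpy as jnp


class WeightedMSELoss:
    """Hypothetical custom loss following the interface above (not part of DiffBio)."""

    def __init__(self, weight: float = 1.0):
        # Configuration is fixed at construction time
        self.weight = weight

    def __call__(self, predictions: jax.Array, targets: jax.Array) -> jax.Array:
        # Reduce to a scalar so the result can be passed straight to jax.grad
        return self.weight * jnp.mean((predictions - targets) ** 2)


loss = WeightedMSELoss(weight=2.0)
preds = jnp.array([1.0, 2.0])
targets = jnp.array([0.0, 0.0])
print(loss(preds, targets))
print(jax.grad(lambda p: loss(p, targets))(preds))
```

Because `weight` is a plain Python attribute, it is treated as a compile-time constant when the loss is used inside `jax.jit`; changing it triggers recompilation rather than being traced.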
Gradient Properties
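All losses are designed for:
- Numerical stability: log-space computations where needed
- Smooth gradients: temperature-controlled soft operations
- JAX compatibility: the losses work under jax.grad and jax.jit; when differentiating or compiling a function whose argument is a Flax NNX model, use the lifted transforms nnx.grad and nnx.jit
```python
from flax import nnx
from diffbio.losses import BatchMixingLoss

batch_loss = BatchMixingLoss()

def loss_fn(model, data):
    result, _, _ = model.apply(data, {}, None)
    return batch_loss(result["embeddings"], data["batch_ids"])

# Gradient computation with respect to the model's nnx.Param variables
grads = nnx.grad(loss_fn)(model, data)

# JIT compilation
@nnx.jit
def compute_loss(model, data):
    return loss_fn(model, data)
```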
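The first two properties can be seen in isolation with plain JAX. The sketch below is illustrative and does not call DiffBio: it contrasts a naive log-sum-exp, which overflows in float32, with `jax.nn.logsumexp`, and shows a temperature-controlled softmin whose gradient is a softmax over its inputs. `naive_log_sum_exp` and `soft_min` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def naive_log_sum_exp(x: jax.Array) -> jax.Array:
    # Exponentiates directly: overflows to inf for large inputs in float32
    return jnp.log(jnp.sum(jnp.exp(x)))


def soft_min(x: jax.Array, temperature: float) -> jax.Array:
    # -t * logsumexp(-x / t): a smooth lower bound on min(x) that tightens as t -> 0
    return -temperature * jax.nn.logsumexp(-x / temperature)


big = jnp.array([1000.0, 1001.0, 1002.0])
print(naive_log_sum_exp(big))  # inf
print(jax.nn.logsumexp(big))   # approximately 1002.41

costs = jnp.array([3.0, 1.0, 2.0])
for t in (1.0, 0.1, 0.01):
    # The gradient of soft_min is softmax(-costs / t): spread over all entries
    # at high t, concentrated on the argmin as t shrinks
    print(t, soft_min(costs, t), jax.grad(soft_min)(costs, t))
```

`soft_min` always lies between `min(x) - t * log(n)` and `min(x)`, so lowering the temperature tightens the approximation but concentrates the gradient on fewer entries; the temperature trades approximation accuracy against gradient smoothness.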
Next Steps