Metric Losses
DiffBio provides differentiable metric-based loss functions that can serve as training objectives, plus exact evaluation metrics backed by calibrax.
DifferentiableAUROC
Smooth training surrogate for the Area Under the ROC Curve (AUROC). It replaces the hard indicator in the Wilcoxon-Mann-Whitney statistic with a sigmoid, so the resulting surrogate is fully differentiable and JIT-compatible.
For every (positive, negative) pair, the hard AUROC checks whether the positive score exceeds the negative score. This module replaces that indicator with \(\sigma((s_+ - s_-) / T)\), yielding a smooth surrogate whose gradient can drive optimisation.
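Quick Start
```python
import jax.numpy as jnp
from diffbio.losses.metric_losses import DifferentiableAUROC

auroc_loss = DifferentiableAUROC(temperature=1.0)
predictions = jnp.array([0.9, 0.8, 0.1, 0.2])
labels = jnp.array([1.0, 1.0, 0.0, 0.0])
# Perfectly separated, yet ~0.67 at temperature=1.0 because the score gaps (0.6 to 0.8)
# are small relative to T; lower temperatures push the value toward the exact AUROC of 1.0
value = auroc_loss(predictions, labels)
```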
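As a worked check of the Quick Start value, assuming the surrogate is the mean of \(\sigma((s_+ - s_-)/T)\) over all positive-negative pairs, the inputs at \(T = 1\) yield four pairs with score gaps 0.8, 0.7, 0.7 and 0.6:

\[
\frac{1}{4}\Big[\sigma(0.8) + \sigma(0.7) + \sigma(0.7) + \sigma(0.6)\Big] \approx \frac{0.690 + 0.668 + 0.668 + 0.646}{4} \approx 0.668
\]

At \(T = 0.1\) the same gaps scale to 8, 7, 7 and 6 and the mean rises to roughly 0.999, which is why lower temperatures track the exact AUROC of 1.0 more closely.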
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| temperature | float | 1.0 | Sigmoid sharpness (lower = closer to hard indicator) |
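The temperature trades fidelity for gradient signal. The sketch below looks at a single relaxed pair term \(\sigma(d / T)\) for a correctly ordered pair with score gap \(d = 0.7\) (0.8 vs 0.1 from the Quick Start). `pair_term` is a hypothetical helper for illustration, not part of DiffBio. As \(T\) shrinks, the term approaches the hard indicator value of 1, but its gradient with respect to the gap collapses toward zero.

```python
import jax


def pair_term(gap, temperature):
    """Hypothetical helper: sigmoid relaxation of [s_pos > s_neg] for one pair with gap s_pos - s_neg."""
    return jax.nn.sigmoid(gap / temperature)


# Gradient of the relaxed indicator with respect to the score gap
pair_grad = jax.grad(pair_term, argnums=0)

gap = 0.7  # a correctly ordered pair
results = {}
for temperature in (1.0, 0.1, 0.01):
    value = float(pair_term(gap, temperature))
    grad = float(pair_grad(gap, temperature))
    results[temperature] = (value, grad)
    print(f"T={temperature:<5} value={value:.4f} grad={grad:.6f}")
```

Very small temperatures therefore bring back the vanishing-gradient behaviour of the hard indicator: pairs that are already well ordered stop contributing to training. A temperature on the scale of typical score gaps keeps both a reasonable approximation and a usable gradient.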
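Usage as Training Loss
```python
from flax import nnx
from diffbio.losses.metric_losses import DifferentiableAUROC

auroc_loss = DifferentiableAUROC(temperature=1.0)

# model, data and labels are assumed to be defined elsewhere
def auroc_training_loss(model, data, labels):
    # Forward pass; only the first return value (the output dict) is used
    result, _, _ = model.apply(data, {}, None)
    predictions = result["predictions"]
    # Maximize AUROC -> minimize negative AUROC
    return -auroc_loss(predictions, labels)

grads = nnx.grad(auroc_training_loss)(model, data, labels)
```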
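Algorithm
\[
\text{AUROC}_{\text{soft}} = \frac{1}{|P|\,|N|} \sum_{i \in P} \sum_{j \in N} \sigma\!\left(\frac{s_i - s_j}{T}\right)
\]
where \(P\) and \(N\) are the sets of positive and negative samples, \(s_i\) is the predicted score of sample \(i\), \(\sigma\) is the logistic sigmoid, and \(T\) is the temperature.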
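The double sum maps directly onto a matrix of pairwise score differences. The sketch below is one plausible vectorised JAX implementation, assuming 0/1 float labels and at least one sample of each class; `soft_auroc` is a hypothetical name, and DiffBio's own implementation may differ (for example in how it handles batching or empty classes).

```python
import jax
import jax.numpy as jnp


def soft_auroc(scores, labels, temperature=1.0):
    """Sigmoid-relaxed Wilcoxon-Mann-Whitney AUROC (sketch, not DiffBio's implementation).

    scores: (n,) float array; labels: (n,) array of 0.0 / 1.0.
    """
    pos = labels > 0.5
    neg = ~pos
    # gaps[i, j] = s_i - s_j for every ordered pair of samples
    gaps = scores[:, None] - scores[None, :]
    # Keep only (positive i, negative j) pairs; a mask keeps shapes static under jit
    pair_mask = (pos[:, None] & neg[None, :]).astype(scores.dtype)
    soft_wins = jax.nn.sigmoid(gaps / temperature)
    # Divide by |P| * |N|, the number of unmasked pairs
    return jnp.sum(soft_wins * pair_mask) / jnp.sum(pair_mask)


scores = jnp.array([0.9, 0.8, 0.1, 0.2])
labels = jnp.array([1.0, 1.0, 0.0, 0.0])

value = jax.jit(soft_auroc)(scores, labels)
# Gradient of the training loss -AUROC with respect to the scores
grads = jax.grad(lambda s: -soft_auroc(s, labels))(scores)
print(value, grads)
```

The gradient of the negated surrogate is negative for positive samples and positive for negative samples, so a gradient-descent step raises positive scores and lowers negative ones. Masking the pair matrix instead of boolean indexing is what keeps the function traceable by `jax.jit`; the \(O(n^2)\) pair matrix is fine for minibatches but grows quickly with batch size.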
ExactAUROC
Exact AUROC metric for evaluation. It delegates the threshold sweep and trapezoidal-rule integration to calibrax.metrics.functional.classification.roc_auc.
Use this for evaluation only. The result depends on the predictions only through their sort order (computed with argsort), so it is piecewise constant: its gradient with respect to the predictions is zero almost everywhere and carries no training signal.
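To see the zero-gradient effect directly, the sketch below computes exact AUROC with the rank-sum (Mann-Whitney) formula, which agrees with the trapezoidal-rule AUROC when there are no tied scores, and differentiates it with `jax.grad`. `rank_auroc` is a hypothetical helper for illustration, not calibrax's roc_auc; it assumes 0/1 float labels and untied scores.

```python
import jax
import jax.numpy as jnp


def rank_auroc(scores, labels):
    """Exact AUROC via the Mann-Whitney rank-sum formula (sketch; assumes no tied scores)."""
    # argsort(argsort(x)) gives 0-based ranks; the integer result carries no gradient
    ranks = jnp.argsort(jnp.argsort(scores)).astype(scores.dtype) + 1.0
    n_pos = jnp.sum(labels)
    n_neg = labels.shape[0] - n_pos
    rank_sum_pos = jnp.sum(ranks * labels)
    # U statistic of the positives, normalised by the number of (pos, neg) pairs
    return (rank_sum_pos - n_pos * (n_pos + 1.0) / 2.0) / (n_pos * n_neg)


scores = jnp.array([0.9, 0.8, 0.1, 0.2])
labels = jnp.array([1.0, 1.0, 0.0, 0.0])

value = rank_auroc(scores, labels)
grads = jax.grad(rank_auroc)(scores, labels)
print(value, grads)
```

Because the ranks leave argsort as integers, nothing flows back to the scores and `jax.grad` returns all zeros, even though the metric itself is correct. The sigmoid surrogate avoids this by replacing each hard comparison with a smooth one.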
Quick Start
```python
import jax.numpy as jnp
from diffbio.losses.metric_losses import ExactAUROC

exact = ExactAUROC()
predictions = jnp.array([0.9, 0.8, 0.1, 0.2])
labels = jnp.array([1.0, 1.0, 0.0, 0.0])
value = exact(predictions, labels)  # 1.0: every positive outscores every negative
```
Parameters
ExactAUROC has no learnable parameters.
When to Use Which
| Metric | Differentiable | Use Case |
|---|---|---|
| DifferentiableAUROC | Yes | Training objective for gradient-based optimizers |
| ExactAUROC | No | Evaluation, reporting, model selection |
Training Pattern
A typical workflow uses the differentiable variant for training and the exact variant for evaluation:
```python
from diffbio.losses.metric_losses import DifferentiableAUROC, ExactAUROC

# Training: smooth surrogate, safe to differentiate
train_auroc = DifferentiableAUROC(temperature=1.0)

def train_loss(model, data, labels):
    result, _, _ = model.apply(data, {}, None)
    # Negate so that minimizing the loss maximizes AUROC
    return -train_auroc(result["scores"], labels)

# Evaluation: exact metric, no gradients needed
eval_auroc = ExactAUROC()

def evaluate(model, data, labels):
    result, _, _ = model.apply(data, {}, None)
    return eval_auroc(result["scores"], labels)
```
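For a runnable end-to-end version of this pattern without a DiffBio model, the sketch below swaps in a hypothetical two-feature linear scorer on synthetic data, plus pure-JAX stand-ins `soft_auroc` and `exact_auroc` for the two DiffBio classes. The exact metric here counts hard pairwise wins with ties as one half, which matches the trapezoidal AUROC but is not calibrax's implementation.

```python
import jax
import jax.numpy as jnp

LR = 0.5  # step size for plain gradient descent


def _pair_mask(labels):
    # 1.0 for every (positive i, negative j) pair, 0.0 elsewhere
    pos = labels > 0.5
    return (pos[:, None] & ~pos[None, :]).astype(jnp.float32)


def soft_auroc(scores, labels, temperature=0.5):
    # Training surrogate: sigmoid-relaxed pairwise wins
    gaps = scores[:, None] - scores[None, :]
    mask = _pair_mask(labels)
    return jnp.sum(jax.nn.sigmoid(gaps / temperature) * mask) / jnp.sum(mask)


def exact_auroc(scores, labels):
    # Evaluation metric: hard pairwise wins, ties count as one half
    gaps = scores[:, None] - scores[None, :]
    wins = jnp.where(gaps > 0, 1.0, jnp.where(gaps == 0, 0.5, 0.0))
    mask = _pair_mask(labels)
    return jnp.sum(wins * mask) / jnp.sum(mask)


# Synthetic data: labels follow a linear rule the scorer can recover
x = jax.random.normal(jax.random.PRNGKey(0), (200, 2))
labels = (x[:, 0] + 0.5 * x[:, 1] > 0).astype(jnp.float32)


def train_loss(w):
    # Maximise the surrogate AUROC by minimising its negative
    return -soft_auroc(x @ w, labels)


@jax.jit
def step(w):
    return w - LR * jax.grad(train_loss)(w)


w = jnp.zeros(2)
init_auc = exact_auroc(x @ w, labels)  # all scores tie at zero, so every pair counts as one half
for _ in range(100):
    w = step(w)
final_auc = exact_auroc(x @ w, labels)
print(f"exact AUROC before: {float(init_auc):.3f}, after: {float(final_auc):.3f}")
```

Only the surrogate passes through `jax.grad`; the exact metric is evaluated outside the gradient path, mirroring the split between DifferentiableAUROC and ExactAUROC.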
Next Steps
- See Single-Cell Losses for single-cell training objectives
- Explore Statistical Losses for count-based losses