Core Concepts
This page covers the ideas you need before working with DiffBio. It focuses on what DiffBio provides — not general JAX or Flax documentation.
Smooth Approximations
Traditional bioinformatics algorithms rely on discrete operations that block gradient flow. DiffBio replaces each one with a temperature-controlled smooth approximation:
| Discrete Operation | DiffBio Approximation | Where Used |
|---|---|---|
| `max(a, b)` | `soft_ops.max(x, softness=τ)` | Smith-Waterman recurrence |
| `argmax(x)` | `soft_ops.argmax(x, softness=τ)` | Soft k-means clustering |
| `x > threshold` | `soft_ops.greater(x, t, softness=τ)` | Quality filtering, doublet scoring |
| `sort(x)` | `soft_ops.sort(x, softness=τ)` | Gene ranking, quantile normalization |
| `top_k(x, k)` | `soft_ops.top_k(x, k, softness=τ)` | Feature selection |
| Hard counting | Weighted accumulation | Pileup generation |
| Hard assignment | Soft assignment probabilities | Batch correction, cell annotation |
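The sort and top-k rows can be illustrated with a simple pairwise-sigmoid soft rank. This is a minimal sketch in plain JAX, not the construction `soft_ops.sort` or `soft_ops.top_k` uses internally; the helpers `soft_rank` and `soft_top_k_mask` are hypothetical. Each element's rank is one plus a sigmoid-weighted count of the elements below it, so it converges to the hard 1-based ascending rank as τ shrinks.

```python
import jax
import jax.numpy as jnp

def soft_rank(x, tau):
    """Hypothetical soft 1-based ascending rank via pairwise sigmoid comparisons."""
    # pairwise[i, j] is ~1 when x[i] > x[j], ~0 when x[i] < x[j], and 0.5 on the diagonal
    pairwise = jax.nn.sigmoid((x[:, None] - x[None, :]) / tau)
    # The diagonal contributes 0.5, so adding 0.5 makes the smallest rank 1
    return 0.5 + pairwise.sum(axis=1)

def soft_top_k_mask(x, k, tau):
    """Hypothetical soft membership weight in the k largest entries."""
    n = x.shape[0]
    # Hard top-k entries have ranks n - k + 1 ... n; threshold halfway below that
    return jax.nn.sigmoid((soft_rank(x, tau) - (n - k + 0.5)) / tau)

expression = jnp.array([0.5, -1.0, 2.0, 1.2])
print(soft_rank(expression, 1.0))             # blended, non-integer ranks
print(soft_rank(expression, 1e-3))            # close to the hard ranks [2, 1, 4, 3]
print(soft_top_k_mask(expression, 2, 1e-3))   # close to [0, 0, 1, 1]

# Unlike jnp.argsort, the soft rank has a usable gradient
grad = jax.grad(lambda v: soft_rank(v, 1.0)[0])(expression)
```

Raising the first value raises its own rank (positive gradient) while raising any other value lowers it, which is exactly the signal a hard sort cannot provide.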
DiffBio's `soft_ops` module provides 79 differentiable primitives with 5 smoothness modes (`hard`, `smooth`, `c0`, `c1`, `c2`). See the Soft Operations guide for details.
The temperature τ (passed to `soft_ops` functions as the `softness` argument) controls the trade-off:
- Low τ — closer to the discrete algorithm, sharper decisions, weaker gradients
- High τ — smoother output, stronger gradients, easier to optimize
This is DiffBio's core mechanism: every operator in the library uses some form of smooth relaxation to stay differentiable.
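The effect of τ can be seen with three standard relaxations: log-sum-exp for `max`, softmax for `argmax`, and a sigmoid step for `>`. This is a sketch in plain JAX; the helper names are hypothetical, and DiffBio's `soft_ops` may parameterize softness differently (for example through its smoothness modes).

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers illustrating textbook relaxations, not soft_ops functions

def smooth_max(x, tau):
    # tau * logsumexp(x / tau) upper-bounds max(x) and approaches it as tau -> 0
    return tau * jax.nn.logsumexp(x / tau)

def smooth_argmax(x, tau):
    # Softmax over positions: a probability vector that approaches one-hot(argmax)
    return jax.nn.softmax(x / tau)

def smooth_greater(x, threshold, tau):
    # Sigmoid step: a weight in (0, 1) that approaches the indicator x > threshold
    return jax.nn.sigmoid((x - threshold) / tau)

scores = jnp.array([1.0, 3.0, 2.5])
for tau in (10.0, 1.0, 0.01):
    print(f"tau={tau}:",
          smooth_max(scores, tau),
          smooth_argmax(scores, tau),
          smooth_greater(scores, 2.0, tau))

# The gradient of the smooth max is exactly the smooth argmax
grad = jax.grad(smooth_max)(scores, 1.0)
```

At τ = 0.01 the outputs are essentially the discrete answers, but the sigmoid is saturated for inputs far from the threshold, so their gradients are near zero: the "weaker gradients" end of the trade-off.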
The Operator Contract
Every DiffBio operator inherits from datarax's `OperatorModule` and exposes a single entry point, `apply(data, state, metadata)`, which returns a three-element tuple whose first element is the result dict:
| Argument | Type | Typical Value | Purpose |
|---|---|---|---|
| `data` | `dict[str, Array]` | `{"counts": jnp.array(...)}` | Input tensors keyed by name |
| `state` | `dict` | `{}` | Per-element state (empty for stateless ops) |
| `metadata` | `dict \| None` | `None` | Optional metadata |
The result dict contains the original input keys plus new keys added by the operator. This makes chaining trivial — the output of one operator is the input to the next.
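```python
# Chain: Impute → Cluster → Pseudotime
# imputer, clusterer, and pseudotime are DiffBio operators constructed beforehand
result, _, _ = imputer.apply({"counts": counts}, {}, None)
# The clusterer reads the "embeddings" key, so expose the imputed matrix under that name
result["embeddings"] = result["imputed_counts"]
result, _, _ = clusterer.apply(result, {}, None)
result, _, _ = pseudotime.apply(result, {}, None)
# result now contains counts, imputed_counts, cluster_assignments, pseudotime, ...
```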
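To make the contract concrete without depending on datarax, here is a minimal sketch with two hypothetical stand-in classes, `LogNormalizeSketch` and `RowScaleSketch`, that mimic the `apply(data, state, metadata)` signature. Real DiffBio operators subclass datarax's `OperatorModule`; passing `state` and `metadata` straight through as the second and third return values is an assumption made for illustration.

```python
import jax.numpy as jnp

class LogNormalizeSketch:
    """Hypothetical stand-in mimicking the apply() contract (not a DiffBio class)."""

    def apply(self, data, state, metadata):
        # Keep every input key and add one new key
        result = {**data, "log_counts": jnp.log1p(data["counts"])}
        return result, state, metadata

class RowScaleSketch:
    """Second hypothetical stand-in that consumes the first one's output key."""

    def apply(self, data, state, metadata):
        x = data["log_counts"]
        # Scale each row to sum to ~1 (epsilon guards against all-zero rows)
        result = {**data, "scaled": x / (x.sum(axis=1, keepdims=True) + 1e-8)}
        return result, state, metadata

counts = jnp.array([[0.0, 1.0, 3.0], [2.0, 2.0, 0.0]])
result, _, _ = LogNormalizeSketch().apply({"counts": counts}, {}, None)
result, _, _ = RowScaleSketch().apply(result, {}, None)
print(sorted(result))  # ['counts', 'log_counts', 'scaled']
```

Because each step returns a superset of its input dict, any downstream step can read any key produced upstream, including the original `counts`.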
Operator Configuration
Each operator has a frozen dataclass config that defines its parameters:
```python
from flax import nnx
from diffbio.operators.singlecell import SoftKMeansClustering, SoftClusteringConfig

config = SoftClusteringConfig(
    n_clusters=10,
    n_features=50,
    temperature=1.0,
)
operator = SoftKMeansClustering(config, rngs=nnx.Rngs(42))
```
Configs are separate from the operator so you can serialize, compare, and reproduce configurations independently of model weights.
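What a frozen dataclass buys you can be shown with the standard library alone. The sketch below uses a hypothetical `ExampleClusteringConfig` that mirrors the fields above; it is not DiffBio's class, and real configs may hold fields that need custom serialization beyond `json`.

```python
import dataclasses
import json

@dataclasses.dataclass(frozen=True)
class ExampleClusteringConfig:
    """Hypothetical config mirroring the fields used above; not DiffBio's class."""
    n_clusters: int = 10
    n_features: int = 50
    temperature: float = 1.0

config = ExampleClusteringConfig(n_clusters=10, n_features=50, temperature=1.0)

# Serialize: a plain dict of hyperparameters, independent of any model weights
saved = json.dumps(dataclasses.asdict(config))

# Reproduce: rebuild an equal config from the serialized form
restored = ExampleClusteringConfig(**json.loads(saved))

# Compare: dataclasses generate field-wise equality; replace() returns a new object
sharper = dataclasses.replace(config, temperature=0.1)

# Frozen: in-place mutation raises FrozenInstanceError
mutation_blocked = False
try:
    config.temperature = 0.5
except dataclasses.FrozenInstanceError:
    mutation_blocked = True
```

Since the config can never change after construction, an operator built from it cannot silently drift away from the settings that were saved.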
Operator Domains
DiffBio organizes operators by biological domain. Each domain corresponds to a subpackage under `diffbio.operators`:
| Domain | Subpackage | Examples |
|---|---|---|
| Single-Cell | `singlecell` | Clustering, batch correction, trajectory, imputation, cell annotation |
| Alignment | `alignment` | Smith-Waterman, profile HMM, soft MSA |
| Variant Calling | `variant` | Pileup, classifier, CNV segmentation |
| Normalization | `normalization` | VAE normalizer, UMAP, PHATE |
| Drug Discovery | `drug_discovery` | Molecular fingerprints, ADMET, GNN property prediction |
| Epigenomics | `epigenomics` | Peak calling, chromatin state annotation |
| Multi-omics | `multiomics` | Hi-C contact maps, spatial deconvolution, multi-modal VAE |
| RNA-seq | `rnaseq` | Splicing PSI, motif discovery |
| Preprocessing | `preprocessing` | Adapter removal, error correction, duplicate filtering |
| Statistical | `statistical` | HMM, EM quantification, negative binomial GLM |
Every operator across all domains follows the same apply() contract.
Gradient Flow Through Pipelines
The central value proposition of DiffBio: gradients propagate backward through an entire pipeline of operators.
```mermaid
graph TB
    A["Input Data"] --> B["Operator A"]
    B --> C["Operator B"]
    C --> D["Loss Function"]
    B -. "∂L/∂params_A" .-> B
    C -. "∂L/∂params_B" .-> C
    style A fill:#d1fae5,stroke:#059669,color:#064e3b
    style B fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style C fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style D fill:#fef3c7,stroke:#d97706,color:#78350f
```
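In practice this means you can define a loss at the end of a pipeline and optimize all upstream operator parameters jointly. The snippet below differentiates with respect to the input data; gradients with respect to each operator's parameters travel back along the same path:
```python
import jax
import jax.numpy as jnp
# normalizer and clusterer are DiffBio operators constructed beforehand
def pipeline_loss(data):
    r1, _, _ = normalizer.apply(data, {}, None)
    r2, _, _ = clusterer.apply(r1, {}, None)
    probs = r2["cluster_assignments"]
    # Soft assignments sum to 1 per cell, so a plain .sum() is constant; use mean entropy
    return -(probs * jnp.log(probs + 1e-8)).sum(axis=-1).mean()
grad = jax.grad(pipeline_loss)(data)  # data must hold floating-point arrays only
# grad["counts"] contains ∂loss/∂input through both operators
```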
Every example in the Examples section demonstrates this with a gradient verification step.
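To see gradients reach operator parameters, not just inputs, here is a pure-JAX sketch with two hypothetical stand-in "operators", `normalize` and `soft_assign`, whose parameters live in one nested dict. This is not DiffBio's API: DiffBio operators are NNX modules, where `nnx.grad` would typically play the role `jax.grad` plays here. Because each row of a softmax sums to 1, summing all assignment probabilities would give a constant loss with zero gradient, so the sketch uses mean assignment entropy instead.

```python
import jax
import jax.numpy as jnp

def normalize(params, counts):
    # Hypothetical operator A: log1p transform with a learnable per-gene scale
    return jnp.log1p(counts) * params["gene_scale"]

def soft_assign(params, x, temperature=1.0):
    # Hypothetical operator B: soft k-means probabilities, shape (n_cells, n_clusters)
    sq_dist = jnp.sum((x[:, None, :] - params["centroids"][None, :, :]) ** 2, axis=-1)
    return jax.nn.softmax(-sq_dist / temperature, axis=-1)

def pipeline_loss(params, counts):
    x = normalize(params["norm"], counts)
    probs = soft_assign(params["cluster"], x)
    # Mean assignment entropy: low when every cell is confidently assigned
    entropy = -jnp.sum(probs * jnp.log(probs + 1e-8), axis=-1)
    return entropy.mean()

key_counts, key_centroids = jax.random.split(jax.random.PRNGKey(0))
counts = jax.random.poisson(key_counts, 3.0, shape=(20, 5)).astype(jnp.float32)
params = {
    "norm": {"gene_scale": jnp.ones(5)},
    "cluster": {"centroids": jax.random.normal(key_centroids, (3, 5))},
}

# One call returns gradients for both operators' parameters, in the same nested structure
grads = jax.grad(pipeline_loss)(params, counts)

# A single joint gradient-descent step over every parameter in the pipeline
learning_rate = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
```

The gradient for `gene_scale` is computed through the clustering step, so the normalizer is tuned for what the downstream loss needs rather than in isolation.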
Temperature Scheduling
Because temperature controls the accuracy-trainability trade-off, a common pattern is annealing — start warm (smooth, easy gradients) and cool toward the discrete solution:
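```python
import jax.numpy as jnp

def temperature_schedule(step, initial=10.0, final=0.1, decay_steps=10000):
    """Exponential temperature decay from `initial` to `final` over `decay_steps`."""
    decay_rate = (final / initial) ** (1.0 / decay_steps)  # per-step multiplicative factor
    step = jnp.minimum(step, decay_steps)  # hold at `final` once decay_steps is reached
    return initial * (decay_rate ** step)
```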
When temperature is learnable (nnx.Param), the optimizer can also discover
the right smoothness level from data.
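A learnable temperature can be sketched in plain JAX. In DiffBio the temperature would be an `nnx.Param` inside an operator; here it is a plain pytree entry, and storing it as `log_temperature` (so τ stays positive) is an assumption of this sketch, as are the toy data and the helper names.

```python
import jax
import jax.numpy as jnp

def soft_assign(x, centroids, log_temperature):
    # tau = exp(log_temperature) stays positive however the optimizer moves it
    temperature = jnp.exp(log_temperature)
    sq_dist = jnp.sum((x[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jax.nn.softmax(-sq_dist / temperature, axis=-1)

def loss_fn(params, x, labels):
    probs = soft_assign(x, params["centroids"], params["log_temperature"])
    # Negative log-likelihood of the known labels under the soft assignments
    return -jnp.mean(jnp.log(probs[jnp.arange(x.shape[0]), labels] + 1e-8))

# Hypothetical toy data: two small groups of points with known labels
x = jnp.array([[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.9, 1.1]])
labels = jnp.array([0, 0, 1, 1])
params = {
    "centroids": jnp.array([[0.0, 0.0], [1.0, 1.0]]),
    "log_temperature": jnp.array(0.0),  # start at tau = 1
}

initial_loss = loss_fn(params, x, labels)
learning_rate = 0.1
for _ in range(50):
    grads = jax.grad(loss_fn)(params, x, labels)
    # Plain gradient descent updates the centroids and the temperature together
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

learned_temperature = jnp.exp(params["log_temperature"])
```

On these cleanly labeled points every label is already the nearest centroid, so sharper assignments only help and the learned τ drops below 1. With mislabeled points, very small τ would make their likelihood collapse, so the optimum would tend to sit at a larger τ.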
The Ecosystem
DiffBio is part of a layered JAX/NNX scientific ML ecosystem:
| Library | Role | How DiffBio Uses It |
|---|---|---|
| datarax | Execution and data contracts | Base OperatorModule, config system, Batch, data-source patterns |
| artifex | Modeling substrate | Transformer layers, reusable generative-model components, modality-aligned building blocks |
| opifex | Scientific ML substrate | Multi-objective training, field autodiff, operator-learning, and advanced optimization methods |
| calibrax | Evaluation and governance | Metrics, BenchmarkResult, comparison, profiling, storage, regression checks |
DiffBio owns the biology-specific operator layer. Datarax provides the execution contract, Artifex provides reusable model components, Opifex provides scientific-training infrastructure, and Calibrax provides evaluation and benchmark governance.
```python
# DiffBio operator produces latent representations (vae_normalizer and data defined earlier)
result, _, _ = vae_normalizer.apply(data, {}, None)

# calibrax evaluates clustering quality (labels: cluster labels computed elsewhere)
from calibrax.metrics.functional.clustering import silhouette_score
score = silhouette_score(result["latent_mean"], labels)

# artifex provides reusable generative-model loss components
from artifex.generative_models.core.losses.divergence import gaussian_kl_divergence
kl = gaussian_kl_divergence(result["latent_mean"], result["latent_logvar"])
```
Data Representations
DiffBio operators expect JAX arrays in specific formats depending on the domain:
| Data Type | Shape | Description |
|---|---|---|
| Count matrices | `(n_cells, n_genes)` | Gene expression counts (single-cell) |
| Embeddings | `(n_samples, n_features)` | Latent or reduced representations |
| Sequences (one-hot) | `(length, alphabet_size)` | One-hot encoded DNA/RNA/protein |
| Batch labels | `(n_samples,)` | Integer batch assignments |
| Spatial coordinates | `(n_spots, 2)` | Physical x, y positions |
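One-hot encoding is used for sequences because integer indices are discrete and carry no gradient, whereas a one-hot row is a float vector that sequence-dependent operations can differentiate through, and that can be relaxed into a probability distribution over the alphabet:
```python
import jax.numpy as jnp

# "ACGT" → one-hot with DNA alphabet (A=0, C=1, G=2, T=3)
seq_indices = jnp.array([0, 1, 2, 3])
seq_onehot = jnp.eye(4)[seq_indices]
# Shape: (4, 4); each row is a one-hot vector
```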
Count matrices and embeddings are used as-is — no special encoding needed.
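Because a one-hot row is a special case of a probability vector, a sequence can also be represented as `softmax(logits)` and optimized by gradient descent. The sketch below is illustrative only: `gc_content` is a hypothetical differentiable score, and the alphabet order follows the A=0, C=1, G=2, T=3 convention used in the one-hot example.

```python
import jax
import jax.numpy as jnp

def gc_content(seq_probs):
    # Hypothetical score: expected fraction of G/C positions; seq_probs has shape (length, 4)
    # Columns follow the assumed alphabet order A=0, C=1, G=2, T=3
    return jnp.mean(seq_probs[:, 1] + seq_probs[:, 2])

hard = jnp.eye(4)[jnp.array([0, 1, 2, 3])]  # "ACGT" as one-hot rows
logits = jnp.zeros((4, 4))                  # hypothetical learnable sequence logits
soft = jax.nn.softmax(logits, axis=-1)      # each row is a distribution over A, C, G, T

# The same function accepts hard and soft sequences, and gradients reach the logits
grad_logits = jax.grad(lambda l: gc_content(jax.nn.softmax(l, axis=-1)))(logits)
```

The gradient is positive on the C and G columns and negative on A and T at every position, pointing toward sequences with higher GC content; an integer-index encoding offers no such direction.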
Next Steps
- Quick Start — run your first operator
- Operators Overview — browse available operators by domain
- Examples — runnable examples with visual outputs
- Differentiable Bioinformatics — deeper theory on smooth relaxations