Normalization Operators
DiffBio provides differentiable normalization operators for count data, dimensionality reduction, and sequence embeddings.
Overview
The normalization module provides four operators, each supporting end-to-end gradient-based optimization:
- VAENormalizer: scVI-style VAE for count normalization
- DifferentiableUMAP: Differentiable UMAP dimensionality reduction
- DifferentiablePHATE: Differentiable PHATE dimensionality reduction
- SequenceEmbedding: Learned sequence embeddings
VAENormalizer
Variational autoencoder for count data normalization, inspired by scVI.
Quick Start
```python
import jax
from flax import nnx

from diffbio.operators.normalization import VAENormalizer, VAENormalizerConfig

# Configure VAE normalizer
config = VAENormalizerConfig(
    n_genes=2000,
    latent_dim=10,
    hidden_dim=128,
    n_layers=2,
)

# Create operator
rngs = nnx.Rngs(42)
vae_normalizer = VAENormalizer(config, rngs=rngs)

# Example input: synthetic Poisson counts of shape (n_cells, n_genes);
# replace with your own count matrix
raw_counts = jax.random.poisson(jax.random.PRNGKey(0), 1.0, (100, 2000)).astype("float32")

# Apply to count data
data = {"counts": raw_counts}
result, state, metadata = vae_normalizer.apply(data, {}, None)

# Get outputs
normalized = result["normalized"]        # Normalized counts
latent = result["latent"]                # Latent representation
reconstructed = result["reconstructed"]  # Reconstructed counts
```
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_genes | int | 2000 | Number of genes |
| latent_dim | int | 10 | Latent space dimension |
| hidden_dim | int | 128 | Encoder/decoder hidden dimension |
| n_layers | int | 2 | Number of hidden layers |
VAE Architecture
```mermaid
graph LR
    A["Counts"] --> ENC["Encoder"]
    A --> LIB["Log-library size"]
    ENC --> MS["μ, σ"]
    MS --> Z["z ~ N(μ, σ²)"]
    Z --> DEC["Decoder"]
    LIB --> DEC
    DEC --> OUT["Normalized Counts"]
    style A fill:#d1fae5,stroke:#059669,color:#064e3b
    style ENC fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style MS fill:#dbeafe,stroke:#2563eb,color:#1e3a5f
    style Z fill:#ede9fe,stroke:#7c3aed,color:#4c1d95
    style DEC fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style LIB fill:#dbeafe,stroke:#2563eb,color:#1e3a5f
    style OUT fill:#d1fae5,stroke:#059669,color:#064e3b
```
The VAE learns to:
- Remove technical variation (library size, batch effects)
- Preserve biological variation in latent space
- Output normalized expression values
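To show where the library size enters, here is a minimal NumPy sketch of the scVI-style generative step. It covers three pieces: the reparameterization trick, a softmax decoder that produces per-gene proportions, and scaling those proportions by the observed library size. The encoder outputs and decoder weights are random stand-ins. The helper name scvi_style_decode and the 1e4 rescaling of the normalized output are illustrative assumptions, and this is not DiffBio's VAENormalizer implementation.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scvi_style_decode(counts, mu, log_var, decoder_weights, rng, scale=1e4):
    """Hypothetical helper: sample z, decode to proportions, scale by library size."""
    # Reparameterization: z = mu + sigma * eps keeps sampling differentiable in mu, sigma
    eps = rng.normal(size=mu.shape)
    z = mu + np.exp(0.5 * log_var) * eps

    # Decoder (random linear stand-in): per-cell gene proportions summing to 1
    rho = softmax(z @ decoder_weights)

    # Library size enters the decoder: expected counts = library size * proportions
    log_library = np.log(counts.sum(axis=1, keepdims=True))
    expected_counts = np.exp(log_library) * rho

    # Library-size-free expression: proportions rescaled to a common total
    normalized = scale * rho
    return expected_counts, normalized, rho


rng = np.random.default_rng(0)
n_cells, n_genes, latent_dim = 5, 20, 3
counts = rng.poisson(2.0, size=(n_cells, n_genes)).astype(float)
mu = rng.normal(size=(n_cells, latent_dim))                  # stand-in encoder mean
log_var = rng.normal(scale=0.1, size=(n_cells, latent_dim))  # stand-in encoder log-variance
W = rng.normal(size=(latent_dim, n_genes))                   # stand-in decoder weights
expected_counts, normalized, rho = scvi_style_decode(counts, mu, log_var, W, rng)
```

In this sketch, normalized depends only on rho, so two cells with the same latent code but different sequencing depth receive the same normalized profile. The depth difference is absorbed entirely by the library-size term.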
DifferentiableUMAP
Differentiable UMAP that maps input features to a low-dimensional embedding through a learned projection network, so gradients flow end to end.
Quick Start
```python
import jax
from flax import nnx

from diffbio.operators.normalization import DifferentiableUMAP, UMAPConfig

# Configure UMAP
config = UMAPConfig(
    n_components=2,
    n_neighbors=15,
    input_features=50,
    hidden_dim=32,
)

# Create operator
rngs = nnx.Rngs(42)
umap = DifferentiableUMAP(config, rngs=rngs)

# Example input: 500 samples with 50 features (must match input_features)
high_dim_data = jax.random.normal(jax.random.PRNGKey(0), (500, 50))

# Apply dimensionality reduction
data = {"features": high_dim_data}  # (n_samples, n_features)
result, state, metadata = umap.apply(data, {}, None)

# Get low-dimensional embedding
embedding = result["embedding"]  # (n_samples, n_components)
```
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_components | int | 2 | Output embedding dimension |
| n_neighbors | int | 15 | Number of neighbors for local structure |
| input_features | int | 64 | Input feature dimension |
| hidden_dim | int | 32 | Projection network hidden dimension |
| metric | str | "euclidean" | Distance metric ("euclidean" or "cosine") |
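The metric option controls how input distances are measured before neighbor similarities are built. The NumPy sketch below computes both choices. The helper pairwise_distance is hypothetical and not part of DiffBio's API. It illustrates one practical difference: cosine distance is unchanged when each sample is rescaled, whereas Euclidean distance is not.

```python
import numpy as np


def pairwise_distance(X, metric="euclidean"):
    """All-pairs distance matrix between rows of X (hypothetical helper)."""
    if metric == "euclidean":
        sq = np.sum(X**2, axis=1)
        d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
        return np.sqrt(np.maximum(d2, 0.0))  # clip tiny negatives from round-off
    if metric == "cosine":
        Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
        return 1.0 - Xn @ Xn.T
    raise ValueError(f"unsupported metric: {metric!r}")


rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))
# Rescale each row by a different positive factor (e.g. a per-sample size effect)
scaled = X * np.array([[1.0], [2.0], [3.0], [0.5], [4.0], [1.5]])
d_euc = pairwise_distance(X, "euclidean")
d_cos = pairwise_distance(X, "cosine")
d_cos_scaled = pairwise_distance(scaled, "cosine")
```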
UMAP Loss Function
The differentiable UMAP optimizes a cross-entropy loss:
\[L = -\sum_{ij} [p_{ij} \log q_{ij} + (1-p_{ij}) \log(1-q_{ij})]\]
Where:
- \(p_{ij}\) = high-dimensional similarity (fuzzy set membership)
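- \(q_{ij}\) = low-dimensional similarity, \(q_{ij} = (1 + a\,d_{ij}^{2b})^{-1}\), where \(d_{ij}\) is the distance between points \(i\) and \(j\) in the embedding and \(a, b\) are the learnable curve parameters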
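A minimal NumPy sketch of both similarity matrices and the loss follows.

- **High-dimensional memberships.** These follow a simplified version of the reference UMAP construction. Each point gets an offset \(\rho_i\) (distance to its nearest neighbor) and a bandwidth \(\sigma_i\), found by binary search so that the memberships of its k nearest neighbors sum to \(\log_2 k\). The memberships are then combined with the fuzzy union \(p = w + w^\top - w \circ w^\top\).
- **Low-dimensional kernel.** This is \(q_{ij} = (1 + a\,d_{ij}^{2b})^{-1}\), shown with \(a = b = 1\).

All function names are hypothetical. DiffBio's operator may differ in details such as neighbor search, clipping, and how \(a\) and \(b\) are initialized.

```python
import numpy as np


def fuzzy_membership(X, n_neighbors=15, n_iter=64):
    """Symmetric high-dimensional memberships p_ij (simplified UMAP construction)."""
    n = X.shape[0]
    d = np.sqrt(np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1))
    target = np.log2(n_neighbors)
    w = np.zeros((n, n))
    for i in range(n):
        nbrs = np.argsort(d[i])[1 : n_neighbors + 1]  # k nearest, skipping self
        di = d[i, nbrs]
        rho = di[0]  # distance to the nearest neighbor
        lo, hi, sigma = 0.0, np.inf, 1.0
        for _ in range(n_iter):  # binary search so memberships sum to log2(k)
            s = np.sum(np.exp(-np.maximum(di - rho, 0.0) / sigma))
            if s > target:
                hi = sigma
            else:
                lo = sigma
            sigma = sigma * 2.0 if hi == np.inf else 0.5 * (lo + hi)
        w[i, nbrs] = np.exp(-np.maximum(di - rho, 0.0) / sigma)
    return w + w.T - w * w.T  # fuzzy set union symmetrization


def low_dim_similarity(Y, a=1.0, b=1.0):
    """q_ij = 1 / (1 + a * d_ij^(2b)) for an embedding Y of shape (n, k)."""
    d2 = np.sum((Y[:, None, :] - Y[None, :, :]) ** 2, axis=-1)  # squared distances
    return 1.0 / (1.0 + a * d2**b)  # (d^2)^b == d^(2b)


def umap_cross_entropy(p, q, eps=1e-7):
    """Fuzzy-set cross-entropy summed over off-diagonal pairs."""
    q = np.clip(q, eps, 1.0 - eps)  # keep both logarithms finite
    off_diag = ~np.eye(p.shape[0], dtype=bool)  # exclude self-pairs
    terms = p * np.log(q) + (1.0 - p) * np.log(1.0 - q)
    return -np.sum(terms[off_diag])


rng = np.random.default_rng(0)
X = rng.normal(size=(40, 10))
p = fuzzy_membership(X, n_neighbors=15)
q = low_dim_similarity(rng.normal(size=(40, 2)))  # a random, untrained embedding
loss = umap_cross_entropy(p, q)
```

Each term is a Bernoulli cross-entropy, which is smallest when \(q_{ij} = p_{ij}\). An untrained embedding therefore scores worse than one whose similarities reproduce \(p\).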
Learnable Parameters
```python
# UMAP curve parameters
umap.embedding_head.curve_params  # Packed [a, b] kernel coefficients

# Projection network
umap.embedding_head.projection_backbone  # Direct Artifex MLP backbone
```
DifferentiablePHATE
Differentiable PHATE (Potential of Heat-diffusion for Affinity-based Trajectory Embedding) for dimensionality reduction with end-to-end gradient flow. It is particularly well suited to trajectory-structured data in single-cell analysis.
Quick Start
```python
import jax
from flax import nnx

from diffbio.operators.normalization import DifferentiablePHATE, PHATEConfig

config = PHATEConfig(
    n_components=2,
    n_neighbors=5,
    decay=40.0,
    diffusion_t=10,
    gamma=1.0,
)
rngs = nnx.Rngs(42)
phate = DifferentiablePHATE(config, rngs=rngs)

# Example input: 200 samples with 64 features (the default input_features)
high_dim_data = jax.random.normal(jax.random.PRNGKey(0), (200, 64))

data = {"features": high_dim_data}  # (n_samples, n_features)
result, state, metadata = phate.apply(data, {}, None)

embedding = result["embedding"]  # (n_samples, n_components)
potential_distances = result["potential_distances"]  # (n_samples, n_samples)
diffusion_op = result["diffusion_operator"]  # M^t matrix
```
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_components | int | 2 | Output embedding dimensions |
| n_neighbors | int | 5 | Neighbors for local bandwidth |
| decay | float | 40.0 | Alpha-decaying kernel exponent (higher = sharper) |
| diffusion_t | int | 10 | Diffusion time (matrix power) |
| gamma | float | 1.0 | Informational distance constant (1 = log, 0 = square root) |
| input_features | int | 64 | Input feature dimension |
| hidden_dim | int | 32 | Projection network hidden dimension |
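PHATE Algorithm
- Compute pairwise distances \(d_{ij}\) between samples
- Build the alpha-decay affinity kernel \(K(i,j) = \exp(-(d_{ij} / \sigma_i)^{\alpha})\), where \(\alpha\) is decay and the adaptive bandwidth \(\sigma_i\) is the distance from sample \(i\) to its n_neighbors-th nearest neighbor
- Symmetrize and row-normalize the kernel into a Markov matrix \(M\)
- Compute the diffusion operator \(M^t\) (\(t\) = diffusion_t) via eigendecomposition
- Compute the diffusion potential \(U = -\log(M^t + \epsilon)\) for \(\gamma = 1\), and the potential distances between rows of \(U\)
- Apply classical MDS to the potential distance matrix to obtain the low-dimensional embedding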
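The steps above can be sketched end to end in NumPy. This is a non-differentiable reference sketch written from the published PHATE description, not DiffBio's code, and the function names are hypothetical.

- **Diffusion power.** \(M^t\) comes from an eigendecomposition of the symmetric conjugate of the Markov matrix (the kernel degree-normalized on both sides). Raising \(M\) to the power \(t\) then only requires powering eigenvalues.
- **Potential for \(\gamma \neq 1\).** This branch uses the power transform \(U = (M^t)^c / c\) with \(c = (1-\gamma)/2\), so \(\gamma = 0\) gives the square-root potential. This is our assumption about how intermediate \(\gamma\) values are handled, intended for \(-1 \le \gamma \le 1\).

```python
import numpy as np


def pairwise_euclidean(X):
    """All-pairs Euclidean distances between rows of X, with an exact zero diagonal."""
    sq = np.sum(X**2, axis=1)
    D = np.sqrt(np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0))
    np.fill_diagonal(D, 0.0)
    return D


def phate_sketch(X, n_components=2, n_neighbors=5, decay=40.0, t=10, gamma=1.0, eps=1e-7):
    """Hypothetical reference implementation of the PHATE steps listed above."""
    # 1. Pairwise distances
    D = pairwise_euclidean(X)
    # 2. Alpha-decay kernel; sigma_i = distance to the k-th neighbor (column 0 is self)
    sigma = np.maximum(np.sort(D, axis=1)[:, n_neighbors], 1e-12)
    K = np.exp(-((D / sigma[:, None]) ** decay))
    # 3. Symmetrize; the Markov matrix is M = diag(deg)^-1 K
    K = 0.5 * (K + K.T)
    deg = K.sum(axis=1)
    # 4. M^t via eigendecomposition of A = diag(deg)^-1/2 K diag(deg)^-1/2
    inv_sqrt = 1.0 / np.sqrt(deg)
    A = inv_sqrt[:, None] * K * inv_sqrt[None, :]
    evals, evecs = np.linalg.eigh(0.5 * (A + A.T))
    At = (evecs * evals**t) @ evecs.T  # V diag(lambda^t) V^T
    Mt = inv_sqrt[:, None] * At * np.sqrt(deg)[None, :]  # undo the conjugation
    Mt = np.clip(Mt, 0.0, None)  # remove tiny negatives from round-off
    # 5. Diffusion potential and potential distances between its rows
    if gamma == 1.0:
        U = -np.log(Mt + eps)
    else:
        c = (1.0 - gamma) / 2.0
        U = Mt**c / c
    potential_distances = pairwise_euclidean(U)
    # 6. Classical MDS: double-center squared distances, keep the top eigenvectors
    n = X.shape[0]
    J = np.eye(n) - np.full((n, n), 1.0 / n)
    B = -0.5 * J @ (potential_distances**2) @ J
    vals, vecs = np.linalg.eigh(B)
    top = np.argsort(vals)[::-1][:n_components]
    embedding = vecs[:, top] * np.sqrt(np.maximum(vals[top], 0.0))
    return embedding, potential_distances, Mt


rng = np.random.default_rng(0)
X = rng.normal(size=(30, 5))
embedding, potential_distances, Mt = phate_sketch(X)
embedding_sqrt, _, _ = phate_sketch(X, gamma=0.0)  # square-root potential
```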
Use Cases
- Visualizing developmental trajectories in single-cell data
- Embedding data with branching structures
- Alternative to UMAP when trajectory preservation is important
SequenceEmbedding
Learned embeddings for biological sequences.
Quick Start
```python
import jax
from flax import nnx

from diffbio.operators.normalization import SequenceEmbedding, SequenceEmbeddingConfig

# Configure embedding
config = SequenceEmbeddingConfig(
    alphabet_size=4,
    max_length=100,
    embedding_dim=64,
    n_layers=2,
)

# Create operator
rngs = nnx.Rngs(42)
seq_embed = SequenceEmbedding(config, rngs=rngs)

# Example input: 8 random one-hot DNA sequences of length 100
tokens = jax.random.randint(jax.random.PRNGKey(0), (8, 100), 0, 4)
sequences = jax.nn.one_hot(tokens, 4)  # (n_seqs, seq_length, alphabet_size)

# Compute sequence embeddings
data = {"sequences": sequences}
result, state, metadata = seq_embed.apply(data, {}, None)

embeddings = result["embeddings"]  # (n_seqs, embedding_dim)
```
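Sequences are passed in one-hot form with shape (n_seqs, seq_length, alphabet_size). The NumPy sketch below converts DNA strings into that layout. It assumes an A, C, G, T channel order, zero padding up to max_length, and all-zero rows for ambiguous bases such as N. The DiffBio docs do not confirm any of these conventions, and one_hot_encode is a hypothetical helper.

```python
import numpy as np

ALPHABET = "ACGT"  # assumed channel order; must match the order used in training


def one_hot_encode(seqs, max_length=100, alphabet=ALPHABET):
    """Encode strings as a (n_seqs, max_length, len(alphabet)) float32 array."""
    index = {ch: i for i, ch in enumerate(alphabet)}
    out = np.zeros((len(seqs), max_length, len(alphabet)), dtype=np.float32)
    for n, seq in enumerate(seqs):
        for pos, ch in enumerate(seq.upper()[:max_length]):  # truncate long sequences
            if ch in index:  # unknown symbols (e.g. N) stay all-zero
                out[n, pos, index[ch]] = 1.0
    return out  # positions past each sequence's end remain zero padding


sequences = one_hot_encode(["ACGTN", "ttgca"], max_length=100)
```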
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| alphabet_size | int | 4 | Input alphabet size |
| max_length | int | 100 | Maximum sequence length |
| embedding_dim | int | 64 | Output embedding dimension |
| n_layers | int | 2 | Number of encoder layers |
Embedding Architecture
```mermaid
graph LR
    A["One-hot Sequence"] --> B["Position Encoding"]
    B --> C["Transformer Layers"]
    C --> D["Global Pool"]
    D --> E["Embedding"]
    style A fill:#d1fae5,stroke:#059669,color:#064e3b
    style B fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style C fill:#ede9fe,stroke:#7c3aed,color:#4c1d95
    style D fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style E fill:#d1fae5,stroke:#059669,color:#064e3b
```
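The diagram does not specify which positional encoding or pooling is used. The NumPy sketch below takes one common choice:

- Project the one-hot input linearly (random weights stand in for learned ones).
- Add the sinusoidal encoding from the original Transformer paper.
- Mean-pool over non-padded positions.

The transformer layers are omitted. Treat this as an illustration of the data flow, not DiffBio's implementation; both function names are hypothetical.

```python
import numpy as np


def sinusoidal_positions(length, dim):
    """Sinusoidal position encoding of shape (length, dim); dim must be even."""
    pos = np.arange(length)[:, None]
    i = np.arange(dim // 2)[None, :]
    angles = pos / np.power(10000.0, 2 * i / dim)
    pe = np.zeros((length, dim))
    pe[:, 0::2] = np.sin(angles)  # even channels
    pe[:, 1::2] = np.cos(angles)  # odd channels
    return pe


def embed_and_pool(one_hot, weights):
    """Project one-hot input, add positions, and mean-pool over real positions."""
    n_seqs, length, _ = one_hot.shape
    h = one_hot @ weights + sinusoidal_positions(length, weights.shape[1])
    # (transformer layers would transform h here; omitted in this sketch)
    mask = one_hot.sum(axis=-1, keepdims=True) > 0  # True at non-padded positions
    return (h * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1)


rng = np.random.default_rng(0)
x = np.zeros((2, 10, 4))
x[0, np.arange(10), rng.integers(0, 4, size=10)] = 1.0  # full-length sequence
x[1, np.arange(6), rng.integers(0, 4, size=6)] = 1.0    # 6 bases + 4 padding positions
W = rng.normal(size=(4, 8))
pooled = embed_and_pool(x, W)  # (n_seqs, embedding_dim)
```

Masking the pool keeps padding from diluting the embedding of short sequences, so a sequence's embedding does not depend on how much padding surrounds it.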
Training with Normalization
VAE Training
```python
from diffbio.losses.statistical_losses import VAELoss

vae_loss = VAELoss(kl_weight=1.0)


def vae_loss_fn(normalizer, counts):
    """Compute the VAE training objective for a batch of counts."""
    data = {"counts": counts}
    result, _, _ = normalizer.apply(data, {}, None)

    # Reconstruction + KL loss
    loss = vae_loss(
        reconstructed=result["reconstructed"],
        target=counts,
        mu=result["mu"],
        log_var=result["log_var"],
    )
    return loss
```
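When the encoder outputs a diagonal Gaussian \(\mathcal{N}(\mu, \sigma^2)\) and the prior is \(\mathcal{N}(0, I)\), the KL term has the closed form \(\mathrm{KL} = -\tfrac{1}{2} \sum_k (1 + \log\sigma_k^2 - \mu_k^2 - \sigma_k^2)\). The NumPy sketch below computes it from mu and log_var and adds it, weighted by kl_weight, to a reconstruction term. This section does not state which reconstruction likelihood VAELoss uses, or whether it averages or sums over cells. The Poisson negative log-likelihood and the per-cell mean below are assumptions for illustration, and all function names are hypothetical.

```python
import numpy as np


def gaussian_kl(mu, log_var):
    """KL(N(mu, exp(log_var)) || N(0, I)) per cell, summed over latent dimensions."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1)


def poisson_nll(rate, counts):
    """Poisson negative log-likelihood per cell, dropping the constant log(counts!)."""
    return np.sum(rate - counts * np.log(rate), axis=-1)


def elbo_loss(rate, counts, mu, log_var, kl_weight=1.0):
    """Mean over cells of reconstruction NLL + kl_weight * KL (assumed form)."""
    return np.mean(poisson_nll(rate, counts) + kl_weight * gaussian_kl(mu, log_var))


rng = np.random.default_rng(0)
counts = rng.poisson(3.0, size=(4, 6)).astype(float)
rate = np.full_like(counts, 3.0)               # stand-in decoder output (expected counts)
mu, log_var = np.zeros((4, 2)), np.zeros((4, 2))  # posterior equal to the prior
loss = elbo_loss(rate, counts, mu, log_var, kl_weight=1.0)
```

Raising kl_weight pulls the posterior toward the prior and yields a smoother latent space, at the cost of reconstruction accuracy.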
UMAP Training
```python
def umap_loss_fn(umap, features):
    """Compute the UMAP cross-entropy loss for a batch of features."""
    data = {"features": features}
    result, _, _ = umap.apply(data, {}, None)

    # UMAP cross-entropy loss between high- and low-dimensional similarities
    p_ij = result["high_dim_similarities"]
    q_ij = result["low_dim_similarities"]
    loss = umap._compute_umap_loss(p_ij, q_ij)
    return loss
```
Use Cases
| Application | Operator | Description |
|---|---|---|
| scRNA-seq normalization | VAENormalizer | Remove technical variation |
| Visualization | DifferentiableUMAP | 2D/3D cell embeddings |
| Sequence similarity | SequenceEmbedding | Compare sequences |
| Trajectory visualization | DifferentiablePHATE | PHATE embedding |
| Feature extraction | All | Learned representations |
Next Steps
- See Single-Cell Operators for clustering and batch correction
- Explore Statistical Operators for differential expression
- Check VAE Loss for training objectives