Multi-omics Operators
DiffBio provides differentiable operators for multi-omics data analysis, including spatial transcriptomics and chromatin conformation.
Overview
Multi-omics operators enable end-to-end optimization of:
- SpatialDeconvolution: Cell type deconvolution for spatial transcriptomics
- HiCContactAnalysis: Chromatin contact analysis for Hi-C data
- DifferentiableSpatialGeneDetector: SpatialDE-style spatial gene detection
- DifferentiableMultiOmicsVAE: Product-of-Experts multi-omics integration
Multi-omics benchmark scope
The current benchmark-backed multi-omics scope is seqFISH spatial deconvolution.
The diffbio.sources.seqfish.SeqFISHSource class records canonical provenance for the RNA and
spatial modalities, and the spatial deconvolution benchmark records both the
modality contract and the operator-output artifact metadata used for comparison
tracking.
Post-DTI stable boundary: benchmark-backed operator support is separate from imported foundation-model promotion.
DiffBio also provides shared indexed loaders for RNA+ATAC or other multi-omics
embedding artifacts via MultiOmicsEmbeddingSource, which aligns imported
embeddings by sample_ids. This is a shared loader and metadata contract, not stable imported multi-omics foundation-model support.
External multi-omics foundation checkpoint loading and generic tokenizer
interchange remain outside the verified scope until benchmark evidence is
attached.
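The sample_ids alignment described above can be pictured with a small standalone sketch. The helper `align_by_sample_ids` and the toy arrays are hypothetical and are not part of the DiffBio API; the sketch only shows the general idea of reordering two embedding tables so that their rows refer to the same samples.

```python
import numpy as np

def align_by_sample_ids(ids_a, emb_a, ids_b, emb_b):
    """Return shared sample IDs and both embeddings reordered to match.

    Hypothetical helper for illustration only; not a DiffBio function.
    """
    index_a = {sid: i for i, sid in enumerate(ids_a)}
    index_b = {sid: i for i, sid in enumerate(ids_b)}
    # Keep the order of modality A and drop samples missing from modality B.
    shared = [sid for sid in ids_a if sid in index_b]
    rows_a = np.array([index_a[sid] for sid in shared], dtype=int)
    rows_b = np.array([index_b[sid] for sid in shared], dtype=int)
    return shared, emb_a[rows_a], emb_b[rows_b]

rna_ids = ["cell_1", "cell_2", "cell_3"]
atac_ids = ["cell_3", "cell_1", "cell_4"]
rna_emb = np.arange(6, dtype=float).reshape(3, 2)   # rows follow rna_ids
atac_emb = np.arange(9, dtype=float).reshape(3, 3)  # rows follow atac_ids

shared_ids, rna_aligned, atac_aligned = align_by_sample_ids(
    rna_ids, rna_emb, atac_ids, atac_emb
)
```

Only samples present in both modalities survive, so downstream code can treat row i of each aligned array as the same sample.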
SpatialDeconvolution
Cell type deconvolution for spatial transcriptomics data.
Quick Start
from flax import nnx
from diffbio.operators.multiomics import SpatialDeconvolution, SpatialDeconvolutionConfig
# Configure deconvolution
config = SpatialDeconvolutionConfig(
    n_cell_types=10,
    n_genes=2000,
    hidden_dim=128,
    temperature=1.0,
)
# Create operator
rngs = nnx.Rngs(42)
deconv = SpatialDeconvolution(config, rngs=rngs)
# Apply deconvolution (spot_expression and cell_type_profiles are your input arrays)
data = {
    "spatial_expression": spot_expression,  # (n_spots, n_genes)
    "reference_profiles": cell_type_profiles,  # (n_cell_types, n_genes)
}
result, state, metadata = deconv.apply(data, {}, None)
# Get cell type proportions
proportions = result["proportions"]  # (n_spots, n_cell_types)
reconstructed = result["reconstructed"]  # Reconstructed expression
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_cell_types | int | 10 | Number of cell types |
| n_genes | int | 2000 | Number of genes |
| hidden_dim | int | 128 | Neural network hidden dimension |
| temperature | float | 1.0 | Softmax temperature |
Deconvolution Model
The deconvolution estimates cell type proportions \(\pi\) such that:
\[X_{spot} \approx \sum_{k} \pi_k R_k, \qquad \pi_k \geq 0, \quad \sum_{k} \pi_k = 1\]
Where:
- \(X_{spot}\) = observed expression at spot
- \(\pi_k\) = proportion of cell type \(k\)
- \(R_k\) = reference profile for cell type \(k\)
DiffBio uses a neural network to predict proportions with soft constraints:
# Soft proportion constraints (sum to 1, non-negative)
proportions = jax.nn.softmax(logits / temperature, axis=-1)
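To make the role of the temperature concrete, here is a minimal standalone sketch. The function `proportions_from_logits` and the toy values are illustrative assumptions rather than DiffBio internals; the sketch only shows that a lower temperature concentrates the softmax and that the reconstruction is the proportion-weighted mix of reference profiles.

```python
import jax
import jax.numpy as jnp

def proportions_from_logits(logits, temperature=1.0):
    """Map unconstrained logits to non-negative proportions that sum to 1."""
    return jax.nn.softmax(logits / temperature, axis=-1)

logits = jnp.array([[2.0, 1.0, 0.0]])  # one spot, three cell types
reference = jnp.array([[1.0, 0.0],     # R_k: (n_cell_types, n_genes)
                       [0.0, 1.0],
                       [1.0, 1.0]])

sharp = proportions_from_logits(logits, temperature=0.5)   # more peaked
smooth = proportions_from_logits(logits, temperature=5.0)  # closer to uniform

# Reconstructed expression: sum over k of pi_k * R_k for each spot.
reconstructed = sharp @ reference  # (n_spots, n_genes)
```

Both outputs satisfy the constraints by construction, so no projection step is needed during optimization.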
Training for Deconvolution
import jax.numpy as jnp
def deconv_loss(deconv, spatial_expr, reference):
    """Reconstruction loss plus an entropy penalty on the proportions."""
    data = {
        "spatial_expression": spatial_expr,
        "reference_profiles": reference,
    }
    result, _, _ = deconv.apply(data, {}, None)
    # Reconstruction loss
    recon_loss = jnp.mean((result["reconstructed"] - spatial_expr) ** 2)
    # Entropy regularization: mean per-spot entropy of the proportions
    proportions = result["proportions"]
    entropy = jnp.mean(-jnp.sum(proportions * jnp.log(proportions + 1e-8), axis=-1))
    # Adding the entropy term penalizes diffuse mixtures, which encourages sparse proportions
    return recon_loss + 0.01 * entropy
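The optimization itself can be sketched without the operator. The example below is an assumption-laden toy: it fits per-spot logits directly with plain gradient descent on synthetic, noise-free data, whereas SpatialDeconvolution predicts proportions with a neural network. The names `loss_fn`, `true_props`, and the learning rate are illustrative only.

```python
import jax
import jax.numpy as jnp

def loss_fn(logits, spatial_expr, reference):
    """Mean squared reconstruction error of softmax-mixed reference profiles."""
    proportions = jax.nn.softmax(logits, axis=-1)
    return jnp.mean((proportions @ reference - spatial_expr) ** 2)

reference = jnp.array([[4.0, 0.0, 1.0],
                       [0.0, 3.0, 1.0]])          # (n_cell_types=2, n_genes=3)
true_props = jnp.array([[0.8, 0.2], [0.3, 0.7]])  # (n_spots=2, n_cell_types=2)
spatial_expr = true_props @ reference             # synthetic spots, no noise

grad_fn = jax.jit(jax.value_and_grad(loss_fn))
logits = jnp.zeros((2, 2))                        # start from uniform proportions
initial_loss, _ = grad_fn(logits, spatial_expr, reference)
for _ in range(500):
    loss, grads = grad_fn(logits, spatial_expr, reference)
    logits = logits - 0.5 * grads                 # plain gradient descent step

estimated = jax.nn.softmax(logits, axis=-1)       # recovered proportions
```

With real data the same pattern applies, with the loss taken over the operator's parameters and an optimizer in place of the hand-written update.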
HiCContactAnalysis
Chromatin contact analysis for Hi-C and related 3C data.
Quick Start
from flax import nnx
from diffbio.operators.multiomics import HiCContactAnalysis, HiCContactAnalysisConfig
# Configure Hi-C analysis
config = HiCContactAnalysisConfig(
    resolution=10000,  # 10kb resolution
    hidden_dim=64,
    n_layers=4,
    distance_decay=True,
)
# Create operator
rngs = nnx.Rngs(42)
hic_analysis = HiCContactAnalysis(config, rngs=rngs)
# Analyze contact matrix (hic_matrix is your input array)
data = {"contact_matrix": hic_matrix}  # (n_bins, n_bins)
result, state, metadata = hic_analysis.apply(data, {}, None)
# Get results
normalized = result["normalized_contacts"]  # Distance-normalized
compartments = result["compartments"]  # A/B compartments
tad_boundaries = result["tad_boundaries"]  # TAD boundary scores
loops = result["loop_scores"]  # Loop/enhancer contacts
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| resolution | int | 10000 | Bin size in base pairs |
| hidden_dim | int | 64 | Network hidden dimension |
| n_layers | int | 4 | Number of network layers |
| distance_decay | bool | True | Model distance decay |
Hi-C Analysis Components
Distance Normalization
Hi-C contacts decay with genomic distance. DiffBio learns the decay function:
# Learned distance decay
expected_contacts = decay_network(distances)
normalized = observed / (expected_contacts + epsilon)
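For intuition, the sketch below performs the same observed/expected division with a non-learned stand-in: the expected contact at each genomic distance is the empirical mean of the corresponding matrix diagonal. This swaps the learned `decay_network` for a simple baseline and is not DiffBio's implementation; `observed_over_expected` is a hypothetical name.

```python
import jax.numpy as jnp

def observed_over_expected(contacts, epsilon=1e-8):
    """Divide each contact by the mean contact at the same bin separation."""
    n_bins = contacts.shape[0]
    idx = jnp.arange(n_bins)
    distances = jnp.abs(idx[:, None] - idx[None, :])  # separation |i - j|
    # Scatter-add accumulates over repeated indices, giving per-distance sums.
    totals = jnp.zeros(n_bins).at[distances].add(contacts)
    counts = jnp.zeros(n_bins).at[distances].add(1.0)
    expected_by_distance = totals / counts
    expected = expected_by_distance[distances]        # back to (n_bins, n_bins)
    return contacts / (expected + epsilon)

# Toy matrix whose contacts depend only on distance: c_ij = 1 / (|i - j| + 1).
idx = jnp.arange(6)
contacts = 1.0 / (jnp.abs(idx[:, None] - idx[None, :]) + 1.0)
normalized = observed_over_expected(contacts)
```

Because the toy matrix contains nothing but distance decay, the normalized matrix is flat; in real data the remaining structure is what compartment and loop calling work on.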
Compartment Calling
A/B compartments are called from the correlation matrix of the contacts:
# PCA-style compartment calling
correlation = normalize(contact_matrix)
compartment_scores = svd_projection(correlation)
compartments = jax.nn.tanh(compartment_scores) # Soft A/B
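The pseudocode above can be fleshed out as a small standalone sketch. It assumes the projection is the leading eigenvector of the row-wise correlation matrix, which is the usual PCA-style approach; the function name `soft_compartments` and the `sharpness` factor are illustrative and not taken from DiffBio.

```python
import jax.numpy as jnp

def soft_compartments(contacts, sharpness=5.0):
    """Soft A/B scores in (-1, 1) from the leading correlation eigenvector."""
    correlation = jnp.corrcoef(contacts)  # rows are treated as variables
    # eigh returns eigenvalues in ascending order, so the last column leads.
    _, eigenvectors = jnp.linalg.eigh(correlation)
    leading = eigenvectors[:, -1]
    # Standardize so the score scale does not depend on the number of bins.
    leading = leading / (jnp.std(leading) + 1e-8)
    return jnp.tanh(sharpness * leading)

# Toy checkerboard: bins 0-2 and bins 3-5 contact their own group more strongly.
group = jnp.array([0, 0, 0, 1, 1, 1])
contacts = jnp.where(group[:, None] == group[None, :], 2.0, 0.5)
scores = soft_compartments(contacts)
```

The sign of an eigenvector is arbitrary, so which group is labeled A or B has to be fixed by an external signal such as gene density or GC content.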
TAD Detection
Topologically associating domains (TADs) are detected from insulation scores:
# Insulation score calculation
insulation = sliding_window_mean(contacts)
boundaries = jax.nn.sigmoid(-gradient(insulation) / temperature)
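A minimal sketch of the insulation idea on a toy matrix follows. The window definition (mean contact between the bins just upstream and just downstream of a position) is a common choice and an assumption here, as are the names `insulation_score`, `window`, and the toy domains. The last step mirrors the sigmoid expression from the pseudocode above, which scores above 0.5 wherever insulation is falling.

```python
import jax
import jax.numpy as jnp

def insulation_score(contacts, window):
    """Mean contact across each position, between `window` bins on either side."""
    n_bins = contacts.shape[0]
    scores = []
    for i in range(window, n_bins - window + 1):
        # Rows: bins just upstream of position i. Columns: bins just downstream.
        crossing = contacts[i - window:i, i:i + window]
        scores.append(jnp.mean(crossing))
    return jnp.stack(scores)

# Toy matrix: two 4-bin domains, strong contacts inside and weak contacts between.
domain = jnp.array([0, 0, 0, 0, 1, 1, 1, 1])
contacts = jnp.where(domain[:, None] == domain[None, :], 1.0, 0.1)

window = 2
temperature = 0.1
insulation = insulation_score(contacts, window)
boundaries = jax.nn.sigmoid(-jnp.gradient(insulation) / temperature)

# Entry j of `insulation` describes position j + window.
boundary_position = int(jnp.argmin(insulation)) + window
```

In the toy case the insulation minimum falls exactly where the two domains meet, and every step stays differentiable with respect to the contact values.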
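Training Hi-C Models
import jax.numpy as jnp
def binary_cross_entropy(probs, labels, eps=1e-7):
    """Mean binary cross-entropy; assumes probs are probabilities in [0, 1]."""
    probs = jnp.clip(probs, eps, 1.0 - eps)
    return -jnp.mean(labels * jnp.log(probs) + (1.0 - labels) * jnp.log1p(-probs))
def hic_loss(hic_model, contact_matrix, known_loops):
    """Loop detection loss for a Hi-C analysis model."""
    data = {"contact_matrix": contact_matrix}
    result, _, _ = hic_model.apply(data, {}, None)
    # Loop detection loss (assumes loop_scores are probabilities, not logits)
    loop_loss = binary_cross_entropy(result["loop_scores"], known_loops)
    # TAD boundary loss (if labeled)
    # tad_loss = ...
    return loop_loss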
DifferentiableSpatialGeneDetector
SpatialDE-style differentiable spatial gene detection using Gaussian processes. Identifies spatially variable genes by decomposing expression variability into spatial and non-spatial components.
Quick Start
from flax import nnx
from diffbio.operators.multiomics import (
    DifferentiableSpatialGeneDetector,
    SpatialGeneDetectorConfig,
)
# Configure detector
config = SpatialGeneDetectorConfig(
    n_genes=2000,
    lengthscale=1.0,  # RBF kernel lengthscale
    variance=1.0,  # Signal variance
    noise_variance=0.1,  # Noise variance
    hidden_dims=[64, 32],  # Smoothing network layers
    temperature=1.0,  # Classification temperature
    pvalue_threshold=0.05,  # Spatial gene threshold
)
# Create operator
rngs = nnx.Rngs(42)
detector = DifferentiableSpatialGeneDetector(config, rngs=rngs)
# Apply spatial gene detection
data = {
    "spatial_coords": coords,  # (n_spots, 2) - Spatial coordinates
    "expression": expression,  # (n_spots, n_genes) - Gene expression
    "total_counts": total_counts,  # (n_spots,) - Optional normalization
}
result, state, metadata = detector.apply(data, {}, None)
# Get results
fsv = result["fsv"]  # Fraction of Spatial Variance
spatial_pvalues = result["spatial_pvalues"]  # P-values for spatial patterns
is_spatial = result["is_spatial"]  # Soft spatial gene indicator
smoothed = result["smoothed_expression"]  # GP-smoothed expression
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_genes | int | 2000 | Number of genes to analyze |
| lengthscale | float | 1.0 | RBF kernel lengthscale (spatial range) |
| variance | float | 1.0 | Signal variance (σ²_s) |
| noise_variance | float | 0.1 | Noise variance (σ²_e) |
| hidden_dims | list[int] | [64, 32] | Smoothing network dimensions |
| temperature | float | 1.0 | Temperature for soft thresholding |
| pvalue_threshold | float | 0.05 | Threshold for spatial classification |
| learnable_kernel | bool | True | Whether kernel params are learnable |
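One plausible reading of how temperature and pvalue_threshold interact is a sigmoid relaxation of the hard test "p-value below threshold". The sketch below is an assumption for illustration, not a statement of how DiffBio computes is_spatial; `soft_spatial_indicator` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def soft_spatial_indicator(pvalues, pvalue_threshold=0.05, temperature=1.0):
    """Sigmoid relaxation of the hard test `pvalue < pvalue_threshold`."""
    return jax.nn.sigmoid((pvalue_threshold - pvalues) / temperature)

pvalues = jnp.array([0.001, 0.05, 0.5])
soft = soft_spatial_indicator(pvalues, temperature=0.01)
```

Under this form a value of 0.5 corresponds to a p-value exactly at the threshold, and smaller temperatures push the indicator toward a hard 0/1 decision while keeping it differentiable.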
Spatial Variance Model
The model decomposes the expression \(y\) of a gene at location \(x\) as:
\[y = f(x) + \epsilon\]
Where:
- \(f(x) \sim \mathcal{GP}(0, K)\) is the spatial component with RBF kernel
- \(\epsilon \sim \mathcal{N}(0, \sigma^2_e)\) is the non-spatial noise
The Fraction of Spatial Variance (FSV) quantifies spatial structure:
\[\text{FSV} = \frac{\sigma^2_s}{\sigma^2_s + \sigma^2_e}\]
RBF Kernel
The squared exponential (RBF) kernel models spatial covariance:
\[K(x, x') = \sigma^2_s \exp\left(-\frac{\lVert x - x' \rVert^2}{2\ell^2}\right)\]
Where:
- \(\sigma^2_s\) = signal variance (spatial component strength)
- \(\ell\) = lengthscale (characteristic spatial range)
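The kernel and the variance ratio can be written out in a few lines of JAX. This is a standalone sketch using the standard squared exponential form and the plain ratio of signal variance to total variance; the function names are illustrative and the operator's internal computation may differ (for example in how the kernel is scaled).

```python
import jax.numpy as jnp

def rbf_kernel(coords, variance=1.0, lengthscale=1.0):
    """K[i, j] = variance * exp(-||x_i - x_j||^2 / (2 * lengthscale^2))."""
    diffs = coords[:, None, :] - coords[None, :, :]
    sq_dists = jnp.sum(diffs ** 2, axis=-1)
    return variance * jnp.exp(-sq_dists / (2.0 * lengthscale ** 2))

def fraction_spatial_variance(variance, noise_variance):
    """FSV = sigma_s^2 / (sigma_s^2 + sigma_e^2)."""
    return variance / (variance + noise_variance)

coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])  # (n_spots, 2)
K = rbf_kernel(coords, variance=1.0, lengthscale=1.0)
fsv = fraction_spatial_variance(1.0, 0.1)
```

Nearby spots get covariance close to the signal variance and distant spots get covariance close to zero, with the lengthscale setting how quickly that falloff happens.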
Training for Spatial Detection
import jax.numpy as jnp
def spatial_loss(detector, data):
    """Negative mean FSV plus a smoothing reconstruction penalty."""
    result, _, _ = detector.apply(data, {}, None)
    # Maximize spatial variance detection
    fsv_loss = -result["fsv"].mean()
    # Smoothing quality (reconstruction)
    smooth_loss = jnp.mean((result["smoothed_expression"] - data["expression"]) ** 2)
    return fsv_loss + 0.1 * smooth_loss
Interpreting Results
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Identify spatially variable genes
spatial_genes = result["is_spatial"] > 0.5
n_spatial = spatial_genes.sum()
# Get the top 100 spatial genes by FSV (descending)
top_spatial_idx = jnp.argsort(result["fsv"])[::-1][:100]
# Visualize smoothed expression of the top-ranked gene
gene_idx = top_spatial_idx[0]
plt.scatter(
    data["spatial_coords"][:, 0],
    data["spatial_coords"][:, 1],
    c=result["smoothed_expression"][:, gene_idx],
)
plt.colorbar(label="smoothed expression")
plt.show()
DifferentiableMultiOmicsVAE
Multi-omics VAE with Product-of-Experts (PoE) latent fusion, inspired by MULTIVI. Jointly integrates data from multiple modalities (RNA-seq, ATAC-seq, protein, etc.) into a shared latent space via per-modality encoders whose posteriors are fused through PoE.
Quick Start
from flax import nnx
from diffbio.operators.multiomics import DifferentiableMultiOmicsVAE, MultiOmicsVAEConfig
config = MultiOmicsVAEConfig(
    modality_dims=[2000, 500],  # RNA-seq (2000 genes), ATAC-seq (500 peaks)
    latent_dim=10,
    hidden_dim=64,
    modality_weight_mode="equal",  # or "learnable"
)
rngs = nnx.Rngs(42)
vae = DifferentiableMultiOmicsVAE(config, rngs=rngs)
data = {
    "rna_counts": rna_expression,  # (n_cells, 2000)
    "atac_counts": atac_counts,  # (n_cells, 500)
}
result, state, metadata = vae.apply(data, {}, None)
latent = result["joint_latent"]  # (n_cells, latent_dim)
rna_recon = result["rna_reconstructed"]  # (n_cells, 2000)
atac_recon = result["atac_reconstructed"]  # (n_cells, 500)
elbo = result["elbo_loss"]  # scalar
Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| modality_dims | list[int] | [2000, 500] | Feature dimension per modality |
| latent_dim | int | 10 | Shared latent space dimension |
| hidden_dim | int | 64 | Encoder/decoder hidden layer width |
| modality_weight_mode | str | "equal" | "equal" or "learnable" reconstruction weights |
| use_gradnorm | bool | False | Use GradNormBalancer for multi-task loss |
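Product-of-Experts Fusion
For \(M\) modalities, the PoE joint posterior is Gaussian with:
\[\text{precision}_{joint} = \sum_{m=1}^{M} \text{precision}_m, \qquad \mu_{joint} = \frac{\sum_{m=1}^{M} \mu_m \, \text{precision}_m}{\sum_{m=1}^{M} \text{precision}_m}\]
where \(\text{precision}_m = \exp(-\text{logvar}_m)\) and the joint variance is \(1 / \text{precision}_{joint}\). If a prior expert is included in the product, it contributes one more term to each sum.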
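The fusion rule can be sketched in a few lines of JAX. This assumes the standard precision-weighted product of diagonal Gaussians with no prior expert, which may differ from what DifferentiableMultiOmicsVAE does internally; `product_of_experts` and the toy values are illustrative.

```python
import jax.numpy as jnp

def product_of_experts(means, logvars):
    """Fuse M diagonal Gaussians (stacked on axis 0) into one joint Gaussian."""
    precisions = jnp.exp(-logvars)                 # 1 / variance per expert
    joint_precision = jnp.sum(precisions, axis=0)
    joint_mean = jnp.sum(means * precisions, axis=0) / joint_precision
    joint_logvar = -jnp.log(joint_precision)       # log of joint variance
    return joint_mean, joint_logvar

# Two modalities, latent_dim = 2. The first latent dimension has a confident
# expert (variance 1) and an uncertain one (variance 3).
means = jnp.array([[0.0, 2.0], [4.0, 2.0]])
logvars = jnp.log(jnp.array([[1.0, 1.0], [3.0, 1.0]]))
joint_mean, joint_logvar = product_of_experts(means, logvars)
```

The joint mean is pulled toward the more confident expert, and the joint variance is never larger than the smallest expert variance, which is why PoE sharpens the posterior as modalities are added.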
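Data Key Convention
With two modalities, the canonical prefixes rna and atac are used automatically, giving the keys rna_counts / atac_counts for inputs and rna_reconstructed / atac_reconstructed for outputs. For other modality counts, keys follow the pattern modality_<i>_counts / modality_<i>_reconstructed.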
Multi-omics Integration
Operators can be combined across data modalities:
import jax.numpy as jnp
from diffbio.operators.multiomics import SpatialDeconvolution
from diffbio.operators.epigenomics import ChromatinStateAnnotator
# spatial_deconv (a SpatialDeconvolution) and chromatin_annotator (a ChromatinStateAnnotator)
# are operator instances constructed beforehand from their configs
# Spatial transcriptomics + ATAC-seq
spatial_expr = ...  # Gene expression per spot
atac_signal = ...  # Chromatin accessibility
ref = ...  # Reference profiles, (n_cell_types, n_genes)
# Deconvolve cell types
deconv_result, _, _ = spatial_deconv.apply(
    {"spatial_expression": spatial_expr, "reference_profiles": ref},
    {}, None
)
# Annotate chromatin states
chrom_result, _, _ = chromatin_annotator.apply(
    {"histone_marks": atac_signal},
    {}, None
)
# Combine for multi-modal analysis (both outputs must share their leading dimension)
combined_features = jnp.concatenate([
    deconv_result["proportions"],
    chrom_result["state_probabilities"],
], axis=-1)
Use Cases
| Application | Operator | Description |
|---|---|---|
| Cell type mapping | SpatialDeconvolution | Estimate cell type proportions per spot in spatial transcriptomics |
| Tissue architecture | SpatialDeconvolution | Understand tissue structure |
| Chromatin structure | HiCContactAnalysis | 3D genome organization |
| Enhancer-promoter | HiCContactAnalysis | Find regulatory contacts |
| TAD analysis | HiCContactAnalysis | Domain boundaries |
| Spatial gene detection | DifferentiableSpatialGeneDetector | Find spatially variable genes |
| Spatial patterns | DifferentiableSpatialGeneDetector | Identify spatial expression patterns |
| Multi-omics integration | DifferentiableMultiOmicsVAE | Integrate RNA + ATAC/protein |
| Shared latent space | DifferentiableMultiOmicsVAE | Joint embedding across modalities |
Next Steps
- See Epigenomics Operators for chromatin analysis
- Explore Single-Cell Operators for reference profiles