Multi-omics Operators API¤
Differentiable operators for multi-omics analysis, including spatial transcriptomics deconvolution, Hi-C contact analysis, spatial gene detection, and multi-omics VAE integration.
SpatialDeconvolution¤
diffbio.operators.multiomics.spatial_deconvolution.SpatialDeconvolution
¤
SpatialDeconvolution(
config: SpatialDeconvolutionConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable spatial transcriptomics deconvolution.
This operator performs cell type deconvolution of spatial transcriptomics spots using reference single-cell profiles. It incorporates spatial context through coordinate embeddings.
Algorithm:
1. Encode spot expression profiles
2. Encode spatial coordinates
3. Combine expression and spatial features
4. Compute attention to reference cell type profiles
5. Apply softmax for cell type proportions
6. Reconstruct expression from proportions
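Steps 4 to 6 of the algorithm can be illustrated with a minimal JAX sketch. This is not the operator's actual code: the function name `soft_proportions`, the scaled dot-product scoring, and the random embeddings standing in for the encoders of steps 1 to 3 are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def soft_proportions(spot_emb, ref_emb, temperature=1.0):
    """Attention of spots over reference cell types, normalised per spot."""
    # Scaled dot-product scores with shape (n_spots, n_cell_types)
    scores = spot_emb @ ref_emb.T / jnp.sqrt(spot_emb.shape[-1])
    # A lower temperature gives sharper (closer to one-hot) proportions
    return jax.nn.softmax(scores / temperature, axis=-1)


key_s, key_r, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
n_spots, n_cell_types, n_genes, hidden = 5, 3, 8, 4

# Hypothetical stand-ins for the encoded spots and encoded reference profiles
spot_emb = jax.random.normal(key_s, (n_spots, hidden))
ref_emb = jax.random.normal(key_r, (n_cell_types, hidden))
reference_profiles = jax.random.uniform(key_p, (n_cell_types, n_genes))

proportions = soft_proportions(spot_emb, ref_emb, temperature=0.5)
# Step 6: each spot is reconstructed as a mixture of reference profiles
reconstructed = proportions @ reference_profiles  # (n_spots, n_genes)
```

Because every step is a smooth function, gradients of a reconstruction loss can flow back to the encoders and to the temperature.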
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
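For intuition, the two smoothing helpers can be sketched as standalone functions. The signatures below are hypothetical and are not taken from TemperatureOperator; they only show the standard logsumexp-based smooth maximum and the softmax-weighted expected index.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_max(x, temperature=1.0):
    """Smooth maximum over the last axis: T * logsumexp(x / T)."""
    return temperature * logsumexp(x / temperature, axis=-1)


def soft_argmax(x, temperature=1.0):
    """Expected index under softmax(x / T) over the last axis."""
    weights = jax.nn.softmax(x / temperature, axis=-1)
    positions = jnp.arange(x.shape[-1], dtype=x.dtype)
    return jnp.sum(weights * positions, axis=-1)


scores = jnp.array([1.0, 3.0, 2.0])
smooth_peak = soft_max(scores, temperature=0.1)      # close to max(scores) = 3.0
smooth_index = soft_argmax(scores, temperature=0.1)  # close to argmax(scores) = 1
```

As the temperature approaches zero both functions approach their hard counterparts; larger temperatures trade accuracy for smoother gradients.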
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | SpatialDeconvolutionConfig | SpatialDeconvolutionConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | SpatialDeconvolutionConfig | Deconvolution configuration. | required |
| rngs | Rngs \| None | Random number generators for initialization. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply spatial deconvolution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing: "spot_expression": spot expression (n_spots, n_genes); "reference_profiles": reference profiles (n_cell_types, n_genes); "coordinates": spot coordinates (n_spots, 2). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
SpatialDeconvolutionConfig¤
diffbio.operators.multiomics.spatial_deconvolution.SpatialDeconvolutionConfig
dataclass
¤
SpatialDeconvolutionConfig(
n_genes: int = 2000,
n_cell_types: int = 10,
hidden_dim: int = 128,
num_layers: int = 2,
spatial_hidden: int = 32,
dropout_rate: float = 0.1,
temperature: float = 1.0,
)
Bases: OperatorConfig
Configuration for SpatialDeconvolution.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_genes | int | Number of genes in expression profiles. |
| n_cell_types | int | Number of reference cell types. |
| hidden_dim | int | Hidden dimension for neural networks. |
| num_layers | int | Number of encoder layers. |
| spatial_hidden | int | Hidden dimension for spatial encoder. |
| dropout_rate | float | Dropout rate for regularization. |
| temperature | float | Temperature for softmax operations. |
HiCContactAnalysis¤
diffbio.operators.multiomics.hic_contact.HiCContactAnalysis
¤
HiCContactAnalysis(
config: HiCContactAnalysisConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable Hi-C contact analysis.
This operator analyzes Hi-C contact matrices to identify chromatin compartments and TAD boundaries using neural networks.
Algorithm:
1. Encode contact patterns per bin
2. Encode genomic bin features
3. Combine contact and feature embeddings
4. Apply attention for context
5. Predict compartment scores
6. Detect TAD boundaries
7. Reconstruct contacts from embeddings
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | HiCContactAnalysisConfig | HiCContactAnalysisConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | HiCContactAnalysisConfig | Analysis configuration. | required |
| rngs | Rngs \| None | Random number generators for initialization. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply Hi-C contact analysis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing: "contact_matrix": Hi-C contact matrix (n_bins, n_bins); "bin_features": bin genomic features (n_bins, bin_features). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
HiCContactAnalysisConfig¤
diffbio.operators.multiomics.hic_contact.HiCContactAnalysisConfig
dataclass
¤
HiCContactAnalysisConfig(
n_bins: int = 1000,
hidden_dim: int = 128,
num_layers: int = 3,
num_heads: int = 4,
bin_features: int = 16,
dropout_rate: float = 0.1,
temperature: float = 1.0,
)
Bases: OperatorConfig
Configuration for HiCContactAnalysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_bins | int | Number of genomic bins. |
| hidden_dim | int | Hidden dimension for neural networks. |
| num_layers | int | Number of encoder layers. |
| num_heads | int | Number of attention heads. |
| bin_features | int | Dimension of input bin features. |
| dropout_rate | float | Dropout rate for regularization. |
| temperature | float | Temperature for softmax operations. |
DifferentiableSpatialGeneDetector¤
diffbio.operators.multiomics.spatial_gene_detection.DifferentiableSpatialGeneDetector
¤
DifferentiableSpatialGeneDetector(
config: SpatialGeneDetectorConfig,
*,
rngs: Rngs,
name: str | None = None,
)
Bases: TemperatureOperator
SpatialDE-style differentiable spatial gene detection.
This operator identifies spatially variable genes using a differentiable Gaussian process approach. It computes a spatial variance score for each gene and provides soft assignments for spatial vs non-spatial genes.
The model decomposes gene expression as
y = f(x) + epsilon
where f(x) ~ GP(0, K) is the spatial component, with variance sigma^2_s, and epsilon ~ N(0, sigma^2_e) is the non-spatial noise.
The Fraction of Spatial Variance (FSV) is: FSV = sigma^2_s / (sigma^2_s + sigma^2_e)
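As a worked instance of the FSV formula, the sketch below evaluates it for three genes with made-up variance components. The helper name and the small `eps` guard against division by zero are assumptions for illustration, not part of the documented API.

```python
import jax.numpy as jnp


def fraction_spatial_variance(sigma2_s, sigma2_e, eps=1e-8):
    """FSV = sigma^2_s / (sigma^2_s + sigma^2_e), guarded by a small eps."""
    return sigma2_s / (sigma2_s + sigma2_e + eps)


# Made-up variance components for three genes
sigma2_s = jnp.array([0.9, 0.5, 0.0])  # spatial variance per gene
sigma2_e = jnp.array([0.1, 0.5, 1.0])  # non-spatial noise variance per gene
fsv = fraction_spatial_variance(sigma2_s, sigma2_e)
```

An FSV near 1 means almost all of a gene's variance is explained by the spatial component; an FSV near 0 means the gene looks like pure noise.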
Input data structure
- spatial_coords: Float[Array, "n_spots 2"] - Spatial coordinates
- expression: Float[Array, "n_spots n_genes"] - Gene expression
- total_counts: Float[Array, "n_spots"] - Total counts per spot
Output data structure (adds):
- spatial_variance: Float[Array, "n_genes"] - Spatial variance per gene
- spatial_pvalues: Float[Array, "n_genes"] - P-values for spatial patterns
- is_spatial: Float[Array, "n_genes"] - Soft spatial gene indicator
- smoothed_expression: Float[Array, "n_spots n_genes"] - GP smoothed expression
- fsv: Float[Array, "n_genes"] - Fraction of Spatial Variance
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | SpatialGeneDetectorConfig | Detector configuration. | required |
| rngs | Rngs | Random number generators. | required |
| name | str \| None | Optional name for the operator. | None |
apply
¤
apply(
data: dict[str, Array],
state: dict[str, Any],
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[
dict[str, Array], dict[str, Any], dict[str, Any] | None
]
Apply spatial gene detection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Array] | Input data containing: spatial_coords: Float[Array, "n_spots 2"]; expression: Float[Array, "n_spots n_genes"]; total_counts: Float[Array, "n_spots"] (optional). | required |
| state | dict[str, Any] | Element state (passed through). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through). | required |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Array], dict[str, Any], dict[str, Any] \| None] | Tuple of (output_data, state, metadata). |
compute_kernel
¤
Compute squared exponential (RBF) kernel matrix.
K(x1, x2) = variance * exp(-||x1 - x2||^2 / (2 * lengthscale^2))
This is the standard kernel used in SpatialDE for modeling spatial covariance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'n1 2'] | First set of spatial coordinates. | required |
| X2 | Float[Array, 'n2 2'] | Second set of spatial coordinates. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'n1 n2'] | Kernel matrix. |
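The kernel formula can be written directly in JAX. The standalone function below is a sketch: in the operator, lengthscale and variance come from the configuration (and may be learnable), whereas here they are plain arguments with a hypothetical function name.

```python
import jax.numpy as jnp


def rbf_kernel(X1, X2, lengthscale=1.0, variance=1.0):
    """K(x1, x2) = variance * exp(-||x1 - x2||^2 / (2 * lengthscale^2))."""
    # Pairwise differences via broadcasting: (n1, n2, 2)
    diff = X1[:, None, :] - X2[None, :, :]
    sq_dist = jnp.sum(diff**2, axis=-1)  # (n1, n2)
    return variance * jnp.exp(-sq_dist / (2.0 * lengthscale**2))


coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = rbf_kernel(coords, coords, lengthscale=1.0, variance=2.0)
```

When both arguments are the same coordinate set, the matrix is symmetric and its diagonal equals the signal variance, since every point is at distance zero from itself.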
compute_spatial_variance
¤
compute_spatial_variance(
coords: Float[Array, "n_spots 2"],
expression: Float[Array, "n_spots n_genes"],
) -> tuple[Float[Array, "n_genes"], Float[Array, "n_genes"]]
Compute spatial variance and FSV for each gene.
Uses a neural network approximation to the GP posterior mean for efficiency. Computes the variance decomposition: total = spatial + residual.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | Float[Array, 'n_spots 2'] | Spatial coordinates. | required |
| expression | Float[Array, 'n_spots n_genes'] | Normalized gene expression. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'n_genes'], Float[Array, 'n_genes']] | Tuple of (spatial_variance, fsv) per gene. |
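The variance decomposition can be illustrated with a short sketch that takes an already smoothed expression matrix as input. How the operator obtains the smoothed values (its neural approximation to the GP posterior mean) is not shown; the function name, the `eps` guard, and the toy data are assumptions.

```python
import jax.numpy as jnp


def variance_decomposition(expression, smoothed, eps=1e-8):
    """Split per-gene variance into spatial and residual parts."""
    spatial_var = jnp.var(smoothed, axis=0)                # variance of the smooth fit
    residual_var = jnp.var(expression - smoothed, axis=0)  # what the fit leaves over
    fsv = spatial_var / (spatial_var + residual_var + eps)
    return spatial_var, fsv


# Toy data: gene 0 follows the x-coordinate, gene 1 is flat plus alternating noise
x = jnp.linspace(0.0, 1.0, 6)
noise = jnp.array([0.1, -0.1, 0.1, -0.1, 0.1, -0.1])
expression = jnp.stack([x, 0.5 + noise], axis=1)          # (n_spots, n_genes)
smoothed = jnp.stack([x, jnp.full_like(x, 0.5)], axis=1)  # hypothetical smoother output

spatial_var, fsv = variance_decomposition(expression, smoothed)
```

In this toy case the first gene is fully explained by the smooth component and the second gene is not explained at all, so their FSV values sit at opposite ends of the [0, 1] range.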
compute_pvalues
¤
Compute differentiable pseudo-p-values for spatial patterns.
Uses a soft approximation to the likelihood ratio test. In SpatialDE, p-values come from comparing the spatial model to a null model without spatial structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fsv | Float[Array, 'n_genes'] | Fraction of Spatial Variance per gene. | required |
| n_spots | int | Number of spatial locations. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'n_genes'] | Soft p-values (lower = more spatially variable). |
SpatialGeneDetectorConfig¤
diffbio.operators.multiomics.spatial_gene_detection.SpatialGeneDetectorConfig
dataclass
¤
SpatialGeneDetectorConfig(
n_genes: int = 2000,
hidden_dims: tuple[int, ...] | list[int] = (64, 32),
temperature: float = 1.0,
pvalue_threshold: float = 0.05,
compute_field_ops: bool = False,
lengthscale: float = 1.0,
variance: float = 1.0,
noise_variance: float = 0.1,
n_inducing_points: int = 100,
learnable_kernel: bool = True,
)
Bases: _SpatialKernelConfig, _SpatialDetectionConfig, OperatorConfig
Configuration for spatial gene detection.
hidden_dims
class-attribute
instance-attribute
¤
create_spatial_gene_detector¤
diffbio.operators.multiomics.spatial_gene_detection.create_spatial_gene_detector
¤
create_spatial_gene_detector(
n_genes: int = 2000,
n_inducing_points: int = 100,
lengthscale: float = 1.0,
variance: float = 1.0,
seed: int = 42,
) -> DifferentiableSpatialGeneDetector
Factory function to create a spatial gene detector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_genes | int | Number of genes to analyze. | 2000 |
| n_inducing_points | int | Number of inducing points for sparse GP. | 100 |
| lengthscale | float | Initial kernel lengthscale. | 1.0 |
| variance | float | Initial signal variance. | 1.0 |
| seed | int | Random seed. | 42 |
Returns:
| Type | Description |
|---|---|
| DifferentiableSpatialGeneDetector | Configured DifferentiableSpatialGeneDetector instance. |
DifferentiableMultiOmicsVAE¤
diffbio.operators.multiomics.multiomics_vae.DifferentiableMultiOmicsVAE
¤
DifferentiableMultiOmicsVAE(
config: MultiOmicsVAEConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: LossBalancingMixin, EncoderDecoderOperator
Multi-omics VAE with Product-of-Experts latent fusion.
For each modality a dedicated encoder produces (mu_m, logvar_m). These are combined via PoE into a joint posterior from which z is sampled. Per-modality decoders then reconstruct counts from z.
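Product-of-Experts fusion of Gaussian posteriors has a closed form: precisions add, and the joint mean is the precision-weighted average of the expert means. The sketch below shows that computation. Whether the operator also includes a standard-normal prior expert is not stated here, so `include_prior` is an assumption exposed as a flag, and the function name is hypothetical.

```python
import jax.numpy as jnp


def product_of_experts(mus, logvars, include_prior=True):
    """Fuse Gaussian experts; mus and logvars have shape (n_modalities, batch, latent_dim)."""
    precisions = jnp.exp(-logvars)  # 1 / sigma_m^2
    if include_prior:
        # Optional standard-normal prior expert: mean 0, precision 1
        mus = jnp.concatenate([jnp.zeros_like(mus[:1]), mus], axis=0)
        precisions = jnp.concatenate([jnp.ones_like(precisions[:1]), precisions], axis=0)
    joint_precision = jnp.sum(precisions, axis=0)
    mu_joint = jnp.sum(mus * precisions, axis=0) / joint_precision
    logvar_joint = -jnp.log(joint_precision)
    return mu_joint, logvar_joint


# Two unit-variance experts with means 1 and 3 (batch = 1, latent_dim = 1)
mus = jnp.array([[[1.0]], [[3.0]]])
logvars = jnp.zeros((2, 1, 1))
mu_joint, logvar_joint = product_of_experts(mus, logvars, include_prior=False)
```

Two equally confident experts meet halfway, and the joint variance is half of either expert's variance, which is why adding modalities tightens the posterior.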
The ELBO objective uses MSE reconstruction loss per modality, optionally weighted by learnable per-modality weights, plus a KL divergence term against a standard-normal prior.
Data keys follow the convention <name>_counts for input and <name>_reconstructed for output. When exactly two modalities are used, the canonical names rna and atac are applied; otherwise modality_<i> is used.
Attributes:
| Name | Type | Description |
|---|---|---|
| encoders | | Per-modality encoder modules. |
| decoders | | Per-modality decoder modules. |
| mu_heads | | Per-modality linear projection for latent mean. |
| logvar_heads | | Per-modality linear projection for latent logvar. |
| log_modality_weights | | Learnable log-weights (only in 'learnable' mode). |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MultiOmicsVAEConfig | Operator configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Run the multi-omics VAE forward pass.
Steps
- Encode each modality to (mu_m, logvar_m).
- PoE fusion -> (mu_joint, logvar_joint).
- Reparameterise -> z.
- Decode each modality from z.
- Compute the loss (negative ELBO) = weighted reconstruction + KL.
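The reparameterisation and loss steps can be sketched as follows. This is a generic VAE formulation under stated assumptions: the function names, the batch-mean reduction, and the way weights are applied are illustrative and not taken from the operator's source.

```python
import jax
import jax.numpy as jnp


def reparameterise(key, mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * logvar) * eps


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over batch."""
    kl = -0.5 * jnp.sum(1.0 + logvar - mu**2 - jnp.exp(logvar), axis=-1)
    return jnp.mean(kl)


def vae_loss(recons, targets, weights, mu, logvar):
    """Weighted per-modality MSE plus KL (a negative-ELBO style loss)."""
    recon_terms = jnp.stack([jnp.mean((r - t) ** 2) for r, t in zip(recons, targets)])
    return jnp.sum(weights * recon_terms) + kl_to_standard_normal(mu, logvar)


mu = jnp.ones((2, 3))      # batch of 2, latent_dim of 3
logvar = jnp.zeros((2, 3))
z = reparameterise(jax.random.PRNGKey(0), mu, logvar)

# Two toy modalities: the first is reconstructed badly, the second perfectly
recons = [jnp.zeros((2, 4)), jnp.zeros((2, 5))]
targets = [jnp.ones((2, 4)), jnp.zeros((2, 5))]
loss = vae_loss(recons, targets, jnp.array([0.5, 0.5]), mu, logvar)
```

Sampling through `mu + sigma * eps` keeps the draw differentiable with respect to the encoder outputs, which is what lets the whole forward pass be trained by gradient descent.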
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary with | required |
| state | PyTree | Operator state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Operator metadata (passed through unchanged). | required |
| random_params | Any | Not used. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (result_data, state, metadata) where result_data contains the original inputs plus |
MultiOmicsVAEConfig¤
diffbio.operators.multiomics.multiomics_vae.MultiOmicsVAEConfig
dataclass
¤
MultiOmicsVAEConfig(
modality_dims: list[int] = (lambda: [2000, 500])(),
latent_dim: int = 10,
hidden_dim: int = 64,
modality_weight_mode: str = "equal",
use_gradnorm: bool = False,
)
Bases: OperatorConfig
Configuration for DifferentiableMultiOmicsVAE.
Attributes:
| Name | Type | Description |
|---|---|---|
| modality_dims | list[int] | Feature dimension for each modality. |
| latent_dim | int | Shared latent space dimension. |
| hidden_dim | int | Hidden layer width for all encoders / decoders. |
| modality_weight_mode | str | How reconstruction losses are weighted. 'equal' gives uniform weight; 'learnable' uses softmax over a learnable log-weight vector. |
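The two weighting modes can be compared with a small sketch. The helper name is hypothetical, and normalising the 'equal' weights to sum to 1 (so they are directly comparable with the softmax output) is an assumption; the operator may use a different scale.

```python
import jax
import jax.numpy as jnp


def modality_weights(mode, n_modalities, log_weights=None):
    """Return one reconstruction-loss weight per modality."""
    if mode == "equal":
        # Assumption: uniform weights normalised to sum to 1
        return jnp.full((n_modalities,), 1.0 / n_modalities)
    if mode == "learnable":
        # Softmax keeps weights positive and summing to 1 during training
        return jax.nn.softmax(log_weights)
    raise ValueError(f"unknown modality_weight_mode: {mode!r}")


equal_w = modality_weights("equal", 2)
learned_w = modality_weights("learnable", 2, log_weights=jnp.log(jnp.array([1.0, 3.0])))

recon_losses = jnp.array([0.8, 0.4])  # made-up per-modality MSE values
weighted_recon = jnp.sum(learned_w * recon_losses)
```

Parameterising the weights through log-weights and a softmax means the optimiser can shift emphasis between modalities without any weight becoming negative.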
Usage Examples¤
Spatial Deconvolution¤
from flax import nnx
from diffbio.operators.multiomics import SpatialDeconvolution, SpatialDeconvolutionConfig

config = SpatialDeconvolutionConfig(n_cell_types=10, n_genes=2000)
deconv = SpatialDeconvolution(config, rngs=nnx.Rngs(42))

# Input keys follow the apply() parameter documentation; arrays are user-supplied
data = {
    "spot_expression": spot_expression,        # (n_spots, n_genes)
    "reference_profiles": cell_type_profiles,  # (n_cell_types, n_genes)
    "coordinates": coordinates,                # (n_spots, 2)
}
result, _, _ = deconv.apply(data, {}, None)
proportions = result["proportions"]
Hi-C Contact Analysis¤
from flax import nnx
from diffbio.operators.multiomics import HiCContactAnalysis, HiCContactAnalysisConfig

config = HiCContactAnalysisConfig(n_bins=1000, hidden_dim=64)
hic_analysis = HiCContactAnalysis(config, rngs=nnx.Rngs(42))

data = {
    "contact_matrix": hic_matrix,  # (n_bins, n_bins)
    "bin_features": bin_features,  # (n_bins, bin_features), 16 by default
}
result, _, _ = hic_analysis.apply(data, {}, None)
compartments = result["compartments"]
tad_boundaries = result["tad_boundaries"]
Spatial Gene Detection¤
from flax import nnx
from diffbio.operators.multiomics import (
    DifferentiableSpatialGeneDetector,
    SpatialGeneDetectorConfig,
    create_spatial_gene_detector,
)

# Using config
config = SpatialGeneDetectorConfig(
    n_genes=2000,
    lengthscale=1.0,
    variance=1.0,
    pvalue_threshold=0.05,
)
detector = DifferentiableSpatialGeneDetector(config, rngs=nnx.Rngs(42))

# Or using factory function
detector = create_spatial_gene_detector(
    n_genes=2000,
    lengthscale=1.0,
)

# Apply spatial gene detection
data = {
    "spatial_coords": coords,      # (n_spots, 2)
    "expression": expression,      # (n_spots, n_genes)
    "total_counts": total_counts,  # (n_spots,) optional
}
result, _, _ = detector.apply(data, {}, None)

# Get spatial gene results
fsv = result["fsv"]                       # Fraction of Spatial Variance
is_spatial = result["is_spatial"]         # Soft spatial indicator
smoothed = result["smoothed_expression"]  # GP-smoothed expression