Multi-omics Operators API¤

Differentiable operators for multi-omics analysis: spatial transcriptomics deconvolution, Hi-C contact analysis, spatial gene detection, and multi-omics VAE integration.

SpatialDeconvolution¤

diffbio.operators.multiomics.spatial_deconvolution.SpatialDeconvolution ¤

SpatialDeconvolution(
    config: SpatialDeconvolutionConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable spatial transcriptomics deconvolution.

This operator performs cell type deconvolution of spatial transcriptomics spots using reference single-cell profiles. It incorporates spatial context through coordinate embeddings.

Algorithm:
  1. Encode spot expression profiles.
  2. Encode spatial coordinates.
  3. Combine expression and spatial features.
  4. Compute attention to reference cell type profiles.
  5. Apply softmax for cell type proportions.
  6. Reconstruct expression from proportions.
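Steps 4 to 6 of the algorithm can be sketched in plain NumPy. This is an illustration under stated assumptions, not the operator's actual implementation: the function name `deconvolve_sketch`, the scaled dot-product form of the attention, and the pre-computed feature arrays are all hypothetical stand-ins for the learned encoders.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def deconvolve_sketch(spot_features, reference_features, reference_profiles,
                      temperature=1.0):
    """Hypothetical sketch of steps 4-6 (attention, softmax, reconstruction)."""
    hidden = spot_features.shape[-1]
    # Step 4: attention scores between spots and reference cell types
    # (scaled dot product is an assumption made for this sketch).
    scores = spot_features @ reference_features.T / np.sqrt(hidden)
    # Step 5: temperature-controlled softmax gives per-spot proportions.
    proportions = softmax(scores / temperature, axis=-1)
    # Step 6: expression is reconstructed as a proportion-weighted mixture.
    reconstructed = proportions @ reference_profiles
    return proportions, reconstructed


rng = np.random.default_rng(0)
n_spots, n_cell_types, n_genes, hidden = 5, 3, 8, 4
spot_features = rng.normal(size=(n_spots, hidden))
reference_features = rng.normal(size=(n_cell_types, hidden))
reference_profiles = rng.gamma(2.0, size=(n_cell_types, n_genes))

proportions, reconstructed = deconvolve_sketch(
    spot_features, reference_features, reference_profiles, temperature=0.5
)
```

Each row of `proportions` sums to one, so the reconstruction is a convex combination of the reference profiles. Lower temperatures push each spot toward a single dominant cell type.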

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection
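To make the inherited helpers concrete, here is a minimal NumPy sketch of a logsumexp-based smooth maximum and a soft argmax. The standalone functions `soft_max_sketch` and `soft_argmax_sketch` are hypothetical; this page does not show the actual signatures of `soft_max()` and `soft_argmax()`, so treat the argument lists as assumptions.

```python
import numpy as np


def soft_max_sketch(x, temperature=1.0):
    """Smooth maximum: temperature * logsumexp(x / temperature)."""
    z = x / temperature
    m = z.max()  # subtract the max for numerical stability
    return float(temperature * (m + np.log(np.exp(z - m).sum())))


def soft_argmax_sketch(x, temperature=1.0):
    """Softmax-weighted average of positions 0..n-1."""
    z = (x - x.max()) / temperature
    weights = np.exp(z)
    weights /= weights.sum()
    return float((weights * np.arange(x.shape[0])).sum())


scores = np.array([0.1, 2.0, 0.3, 1.5])

sharp_max = soft_max_sketch(scores, temperature=0.01)
smooth_max = soft_max_sketch(scores, temperature=1.0)
sharp_pos = soft_argmax_sketch(scores, temperature=0.01)
```

As the temperature approaches zero, the smooth maximum approaches the hard maximum from above and the soft argmax approaches the index of the largest entry. Larger temperatures trade accuracy for smoother gradients.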

Parameters:

Name Type Description Default
config SpatialDeconvolutionConfig

SpatialDeconvolutionConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example
config = SpatialDeconvolutionConfig(n_cell_types=10)
deconv = SpatialDeconvolution(config, rngs=nnx.Rngs(42))
data = {"spot_expression": spots, "reference_profiles": refs, "coordinates": coords}
result, state, meta = deconv.apply(data, {}, None)

apply ¤

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply spatial deconvolution.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:
  • "spot_expression": Spot expression (n_spots, n_genes)
  • "reference_profiles": Reference profiles (n_cell_types, n_genes)
  • "coordinates": Spot coordinates (n_spots, 2)

required
state PyTree

Element state (passed through unchanged)

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged)

required
random_params Any

Not used

None
stats dict[str, Any] | None

Not used

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

  • transformed_data contains:
      - "spot_expression": Original expression
      - "reference_profiles": Original references
      - "coordinates": Original coordinates
      - "cell_proportions": Deconvolved proportions
      - "reconstructed_expression": Reconstructed expression
      - "spatial_embeddings": Spatial feature embeddings
  • state is passed through unchanged
  • metadata is passed through unchanged

SpatialDeconvolutionConfig¤

diffbio.operators.multiomics.spatial_deconvolution.SpatialDeconvolutionConfig dataclass ¤

SpatialDeconvolutionConfig(
    n_genes: int = 2000,
    n_cell_types: int = 10,
    hidden_dim: int = 128,
    num_layers: int = 2,
    spatial_hidden: int = 32,
    dropout_rate: float = 0.1,
    temperature: float = 1.0,
)

Bases: OperatorConfig

Configuration for SpatialDeconvolution.

Attributes:

Name Type Description
n_genes int

Number of genes in expression profiles.

n_cell_types int

Number of reference cell types.

hidden_dim int

Hidden dimension for neural networks.

num_layers int

Number of encoder layers.

spatial_hidden int

Hidden dimension for spatial encoder.

dropout_rate float

Dropout rate for regularization.

temperature float

Temperature for softmax operations.

HiCContactAnalysis¤

diffbio.operators.multiomics.hic_contact.HiCContactAnalysis ¤

HiCContactAnalysis(
    config: HiCContactAnalysisConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable Hi-C contact analysis.

This operator analyzes Hi-C contact matrices to identify chromatin compartments and TAD boundaries using neural networks.

Algorithm:
  1. Encode contact patterns per bin.
  2. Encode genomic bin features.
  3. Combine contact and feature embeddings.
  4. Apply attention for context.
  5. Predict compartment scores.
  6. Detect TAD boundaries.
  7. Reconstruct contacts from embeddings.

Parameters:

Name Type Description Default
config HiCContactAnalysisConfig

HiCContactAnalysisConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example
config = HiCContactAnalysisConfig(n_bins=1000)
analyzer = HiCContactAnalysis(config, rngs=nnx.Rngs(42))
data = {"contact_matrix": contacts, "bin_features": features}
result, state, meta = analyzer.apply(data, {}, None)

apply ¤

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply Hi-C contact analysis.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:
  • "contact_matrix": Hi-C contact matrix (n_bins, n_bins)
  • "bin_features": Bin genomic features (n_bins, bin_features)

required
state PyTree

Element state (passed through unchanged)

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged)

required
random_params Any

Not used

None
stats dict[str, Any] | None

Not used

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

  • transformed_data contains:
      - "contact_matrix": Original contacts
      - "bin_features": Original features
      - "bin_embeddings": Learned bin embeddings
      - "compartment_scores": A/B compartment scores
      - "tad_boundary_scores": TAD boundary probabilities
      - "predicted_contacts": Reconstructed contacts
  • state is passed through unchanged
  • metadata is passed through unchanged

HiCContactAnalysisConfig¤

diffbio.operators.multiomics.hic_contact.HiCContactAnalysisConfig dataclass ¤

HiCContactAnalysisConfig(
    n_bins: int = 1000,
    hidden_dim: int = 128,
    num_layers: int = 3,
    num_heads: int = 4,
    bin_features: int = 16,
    dropout_rate: float = 0.1,
    temperature: float = 1.0,
)

Bases: OperatorConfig

Configuration for HiCContactAnalysis.

Attributes:

Name Type Description
n_bins int

Number of genomic bins.

hidden_dim int

Hidden dimension for neural networks.

num_layers int

Number of encoder layers.

num_heads int

Number of attention heads.

bin_features int

Dimension of input bin features.

dropout_rate float

Dropout rate for regularization.

temperature float

Temperature for softmax operations.

DifferentiableSpatialGeneDetector¤

diffbio.operators.multiomics.spatial_gene_detection.DifferentiableSpatialGeneDetector ¤

DifferentiableSpatialGeneDetector(
    config: SpatialGeneDetectorConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)

Bases: TemperatureOperator

SpatialDE-style differentiable spatial gene detection.

This operator identifies spatially variable genes using a differentiable Gaussian process approach. It computes a spatial variance score for each gene and provides soft assignments for spatial vs non-spatial genes.

The model decomposes gene expression as

y = f(x) + epsilon

where f(x) ~ GP(0, K) is the spatial component and epsilon ~ N(0, sigma^2) is the non-spatial noise.

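The Fraction of Spatial Variance (FSV) is FSV = sigma^2_s / (sigma^2_s + sigma^2_e), where sigma^2_s is the variance of the spatial component f(x) and sigma^2_e is the variance of the noise term epsilon.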
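As a small worked example of the FSV ratio, the sketch below evaluates it for three hypothetical genes. The helper name `fraction_spatial_variance` and the `eps` stabiliser are assumptions for illustration; the operator may guard against division by zero differently.

```python
import numpy as np


def fraction_spatial_variance(spatial_var, noise_var, eps=1e-8):
    """FSV = spatial / (spatial + noise); eps avoids 0/0 (an assumption)."""
    return spatial_var / (spatial_var + noise_var + eps)


# Three hypothetical genes: strongly spatial, weakly spatial, pure noise.
spatial_var = np.array([0.9, 0.1, 0.0])
noise_var = np.array([0.1, 0.9, 1.0])

fsv = fraction_spatial_variance(spatial_var, noise_var)
```

FSV lies in [0, 1]: values near 1 mean most of a gene's variance is explained by the spatial component, and values near 0 mean it is dominated by non-spatial noise.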

Input data structure
  • spatial_coords: Float[Array, "n_spots 2"] - Spatial coordinates
  • expression: Float[Array, "n_spots n_genes"] - Gene expression
  • total_counts: Float[Array, "n_spots"] - Total counts per spot

Output data structure (adds)
  • spatial_variance: Float[Array, "n_genes"] - Spatial variance per gene
  • spatial_pvalues: Float[Array, "n_genes"] - P-values for spatial patterns
  • is_spatial: Float[Array, "n_genes"] - Soft spatial gene indicator
  • smoothed_expression: Float[Array, "n_spots n_genes"] - GP smoothed expression
  • fsv: Float[Array, "n_genes"] - Fraction of Spatial Variance

Example
config = SpatialGeneDetectorConfig(n_genes=2000)
detector = DifferentiableSpatialGeneDetector(config, rngs=nnx.Rngs(42))
result, state, meta = detector.apply(data, {}, None)
spatial_genes = result["is_spatial"] > 0.5

Parameters:

Name Type Description Default
config SpatialGeneDetectorConfig

Detector configuration.

required
rngs Rngs

Random number generators.

required
name str | None

Optional name for the operator.

None

apply ¤

apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]

Apply spatial gene detection.

Parameters:

Name Type Description Default
data dict[str, Array]

Input data containing:
  • spatial_coords: Float[Array, "n_spots 2"]
  • expression: Float[Array, "n_spots n_genes"]
  • total_counts: Float[Array, "n_spots"] (optional)

required
state dict[str, Any]

Element state (passed through).

required
metadata dict[str, Any] | None

Element metadata (passed through).

required

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any], dict[str, Any] | None]

Tuple of (output_data, state, metadata).

compute_kernel ¤

compute_kernel(
    X1: Float[Array, "n1 2"], X2: Float[Array, "n2 2"]
) -> Float[Array, "n1 n2"]

Compute squared exponential (RBF) kernel matrix.

K(x1, x2) = variance * exp(-||x1 - x2||^2 / (2 * lengthscale^2))

This is the standard kernel used in SpatialDE for modeling spatial covariance.
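The kernel formula above can be written directly in NumPy. This is a sketch rather than the method itself: `compute_kernel` takes only the two coordinate arrays, whereas the hypothetical `rbf_kernel` below takes `lengthscale` and `variance` as explicit arguments so that it is self-contained.

```python
import numpy as np


def rbf_kernel(X1, X2, lengthscale=1.0, variance=1.0):
    """K[i, j] = variance * exp(-||X1[i] - X2[j]||^2 / (2 * lengthscale^2))."""
    # Pairwise squared Euclidean distances via broadcasting: (n1, n2).
    sq_dists = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(axis=-1)
    return variance * np.exp(-sq_dists / (2.0 * lengthscale**2))


coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = rbf_kernel(coords, coords, lengthscale=1.0, variance=1.0)
```

With identical inputs the matrix is symmetric and its diagonal equals `variance`. Covariance decays with distance, and `lengthscale` sets how quickly.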

Parameters:

Name Type Description Default
X1 Float[Array, 'n1 2']

First set of spatial coordinates.

required
X2 Float[Array, 'n2 2']

Second set of spatial coordinates.

required

Returns:

Type Description
Float[Array, 'n1 n2']

Kernel matrix.

compute_spatial_variance ¤

compute_spatial_variance(
    coords: Float[Array, "n_spots 2"],
    expression: Float[Array, "n_spots n_genes"],
) -> tuple[Float[Array, "n_genes"], Float[Array, "n_genes"]]

Compute spatial variance and FSV for each gene.

Uses neural network approximation to GP posterior mean for efficiency. Computes variance decomposition: total = spatial + residual.

Parameters:

Name Type Description Default
coords Float[Array, 'n_spots 2']

Spatial coordinates.

required
expression Float[Array, 'n_spots n_genes']

Normalized gene expression.

required

Returns:

Type Description
tuple[Float[Array, 'n_genes'], Float[Array, 'n_genes']]

Tuple of (spatial_variance, fsv) per gene.

compute_pvalues ¤

compute_pvalues(
    fsv: Float[Array, "n_genes"], n_spots: int
) -> Float[Array, "n_genes"]

Compute differentiable pseudo-p-values for spatial patterns.

Uses a soft approximation to the likelihood ratio test. In SpatialDE, p-values come from comparing the spatial model to a null model without spatial structure.

Parameters:

Name Type Description Default
fsv Float[Array, 'n_genes']

Fraction of Spatial Variance per gene.

required
n_spots int

Number of spatial locations.

required

Returns:

Type Description
Float[Array, 'n_genes']

Soft p-values (lower = more spatially variable).

SpatialGeneDetectorConfig¤

diffbio.operators.multiomics.spatial_gene_detection.SpatialGeneDetectorConfig dataclass ¤

SpatialGeneDetectorConfig(
    n_genes: int = 2000,
    hidden_dims: tuple[int, ...] | list[int] = (64, 32),
    temperature: float = 1.0,
    pvalue_threshold: float = 0.05,
    compute_field_ops: bool = False,
    lengthscale: float = 1.0,
    variance: float = 1.0,
    noise_variance: float = 0.1,
    n_inducing_points: int = 100,
    learnable_kernel: bool = True,
)

Bases: _SpatialKernelConfig, _SpatialDetectionConfig, OperatorConfig

Configuration for spatial gene detection.

n_genes class-attribute instance-attribute ¤

n_genes: int = 2000

hidden_dims class-attribute instance-attribute ¤

hidden_dims: tuple[int, ...] | list[int] = (64, 32)

temperature class-attribute instance-attribute ¤

temperature: float = 1.0

pvalue_threshold class-attribute instance-attribute ¤

pvalue_threshold: float = 0.05

compute_field_ops class-attribute instance-attribute ¤

compute_field_ops: bool = False

lengthscale class-attribute instance-attribute ¤

lengthscale: float = 1.0

variance class-attribute instance-attribute ¤

variance: float = 1.0

noise_variance class-attribute instance-attribute ¤

noise_variance: float = 0.1

n_inducing_points class-attribute instance-attribute ¤

n_inducing_points: int = 100

learnable_kernel class-attribute instance-attribute ¤

learnable_kernel: bool = True

create_spatial_gene_detector¤

diffbio.operators.multiomics.spatial_gene_detection.create_spatial_gene_detector ¤

create_spatial_gene_detector(
    n_genes: int = 2000,
    n_inducing_points: int = 100,
    lengthscale: float = 1.0,
    variance: float = 1.0,
    seed: int = 42,
) -> DifferentiableSpatialGeneDetector

Factory function to create a spatial gene detector.

Parameters:

Name Type Description Default
n_genes int

Number of genes to analyze.

2000
n_inducing_points int

Number of inducing points for sparse GP.

100
lengthscale float

Initial kernel lengthscale.

1.0
variance float

Initial signal variance.

1.0
seed int

Random seed.

42

Returns:

Type Description
DifferentiableSpatialGeneDetector

Configured DifferentiableSpatialGeneDetector instance.

DifferentiableMultiOmicsVAE¤

diffbio.operators.multiomics.multiomics_vae.DifferentiableMultiOmicsVAE ¤

DifferentiableMultiOmicsVAE(
    config: MultiOmicsVAEConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: LossBalancingMixin, EncoderDecoderOperator

Multi-omics VAE with Product-of-Experts latent fusion.

For each modality a dedicated encoder produces (mu_m, logvar_m). These are combined via PoE into a joint posterior from which z is sampled. Per-modality decoders then reconstruct counts from z.
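For diagonal Gaussian experts, the Product-of-Experts fusion has a closed form: precisions add, and the joint mean is the precision-weighted average of the expert means. The sketch below shows that computation in NumPy. The function name `product_of_experts` and the `include_prior` flag are hypothetical, and whether the operator multiplies in a standard-normal prior expert is not stated on this page.

```python
import numpy as np


def product_of_experts(mus, logvars, include_prior=True):
    """Fuse per-modality Gaussians N(mu_m, exp(logvar_m)) into one Gaussian."""
    mus = np.stack(mus)          # (n_modalities, n_cells, latent_dim)
    logvars = np.stack(logvars)  # same shape
    precisions = np.exp(-logvars)
    total_precision = precisions.sum(axis=0)
    if include_prior:
        # A N(0, I) prior expert adds precision 1 and contributes mean 0.
        total_precision = total_precision + 1.0
    mu_joint = (mus * precisions).sum(axis=0) / total_precision
    logvar_joint = -np.log(total_precision)
    return mu_joint, logvar_joint


# Two unit-variance experts that disagree on the mean (1 cell, 2 latent dims).
mu_rna, logvar_rna = np.full((1, 2), 1.0), np.zeros((1, 2))
mu_atac, logvar_atac = np.full((1, 2), 3.0), np.zeros((1, 2))

mu_joint, logvar_joint = product_of_experts(
    [mu_rna, mu_atac], [logvar_rna, logvar_atac], include_prior=False
)
mu_prior, logvar_prior = product_of_experts(
    [mu_rna, mu_atac], [logvar_rna, logvar_atac], include_prior=True
)
```

Without the prior, two equally confident experts meet at the midpoint with half the variance. Adding the prior expert shrinks the joint mean toward zero.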

The ELBO objective uses MSE reconstruction loss per modality, optionally weighted by learnable per-modality weights, plus a KL divergence term against a standard-normal prior.

Data keys follow the convention <name>_counts for input and <name>_reconstructed for output. When exactly two modalities are used the canonical names rna and atac are applied; otherwise modality_<i> is used.

Attributes:

Name Type Description
encoders

Per-modality encoder modules.

decoders

Per-modality decoder modules.

mu_heads

Per-modality linear projection for latent mean.

logvar_heads

Per-modality linear projection for latent logvar.

log_modality_weights

Learnable log-weights (only in 'learnable' mode).

Parameters:

Name Type Description Default
config MultiOmicsVAEConfig

Operator configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None

apply ¤

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Run the multi-omics VAE forward pass.

Steps
  1. Encode each modality to (mu_m, logvar_m).
  2. PoE fusion -> (mu_joint, logvar_joint).
  3. Reparameterise -> z.
  4. Decode each modality from z.
  5. Compute ELBO = weighted recon + KL.
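Steps 3 and 5 above can be sketched as follows, assuming diagonal Gaussians and mean-reduced MSE. The helper names and the choice of reductions (sum over latent dimensions, mean over cells) are assumptions for illustration; the operator's exact reductions are not documented here.

```python
import numpy as np


def reparameterise(mu, logvar, rng):
    """z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps


def kl_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), averaged over cells."""
    per_cell = -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=-1)
    return float(np.mean(per_cell))


def elbo_loss(inputs, reconstructions, mu, logvar, weights):
    """Weighted per-modality MSE plus the KL term."""
    recon = sum(
        w * np.mean((x - x_hat) ** 2)
        for w, x, x_hat in zip(weights, inputs, reconstructions)
    )
    return float(recon + kl_standard_normal(mu, logvar))


rng = np.random.default_rng(0)
mu = np.ones((4, 2))       # 4 cells, 2 latent dimensions
logvar = np.zeros((4, 2))
z = reparameterise(mu, logvar, rng)

kl = kl_standard_normal(mu, logvar)
rna = rng.normal(size=(4, 6))
atac = rng.normal(size=(4, 3))
# Perfect reconstructions leave only the KL term in the loss.
loss = elbo_loss([rna, atac], [rna, atac], mu, logvar, weights=[0.5, 0.5])
```

The quantity minimised here is the negative ELBO, which matches the `elbo_loss` output key. Sampling through `reparameterise` keeps the loss differentiable with respect to `mu` and `logvar`.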

Parameters:

Name Type Description Default
data PyTree

Dictionary with <modality>_counts keys, each of shape (n_cells, modality_dim).

required
state PyTree

Operator state (passed through unchanged).

required
metadata dict[str, Any] | None

Operator metadata (passed through unchanged).

required
random_params Any

Not used.

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
PyTree

tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (result_data, state, metadata), where result_data contains the original inputs plus joint_latent, <modality>_reconstructed, and elbo_loss.

MultiOmicsVAEConfig¤

diffbio.operators.multiomics.multiomics_vae.MultiOmicsVAEConfig dataclass ¤

MultiOmicsVAEConfig(
    modality_dims: list[int] = (lambda: [2000, 500])(),
    latent_dim: int = 10,
    hidden_dim: int = 64,
    modality_weight_mode: str = "equal",
    use_gradnorm: bool = False,
)

Bases: OperatorConfig

Configuration for DifferentiableMultiOmicsVAE.

Attributes:

Name Type Description
modality_dims list[int]

Feature dimension for each modality.

latent_dim int

Shared latent space dimension.

hidden_dim int

Hidden layer width for all encoders / decoders.

modality_weight_mode str

How reconstruction losses are weighted. 'equal' gives uniform weight; 'learnable' uses softmax over a learnable log-weight vector.
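The two weighting modes can be illustrated with a short NumPy sketch. The function `modality_weights` is hypothetical, and normalising the 'equal' weights to sum to one is an assumption; the operator may instead use a weight of 1.0 per modality.

```python
import numpy as np


def modality_weights(mode, n_modalities, log_weights=None):
    """Return one reconstruction-loss weight per modality."""
    if mode == "equal":
        # Uniform weights (normalised to sum to 1 in this sketch).
        return np.full(n_modalities, 1.0 / n_modalities)
    if mode == "learnable":
        # Softmax over the log-weight vector, shifted for stability.
        shifted = np.asarray(log_weights) - np.max(log_weights)
        w = np.exp(shifted)
        return w / w.sum()
    raise ValueError(f"Unknown modality_weight_mode: {mode!r}")


equal_w = modality_weights("equal", 2)
learned_w = modality_weights("learnable", 2, log_weights=np.array([0.0, np.log(3.0)]))
```

Parameterising the weights through a softmax keeps them positive and normalised however the underlying log-weights move during training.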

Usage Examples¤

Spatial Deconvolution¤

from flax import nnx
from diffbio.operators.multiomics import SpatialDeconvolution, SpatialDeconvolutionConfig

config = SpatialDeconvolutionConfig(n_cell_types=10, n_genes=2000)
deconv = SpatialDeconvolution(config, rngs=nnx.Rngs(42))

data = {
    "spot_expression": spot_expression,        # (n_spots, n_genes)
    "reference_profiles": cell_type_profiles,  # (n_cell_types, n_genes)
    "coordinates": coords,                     # (n_spots, 2)
}
result, _, _ = deconv.apply(data, {}, None)
proportions = result["cell_proportions"]
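The input arrays in this example are not defined on this page. A minimal way to build synthetic inputs with the shapes and keys listed in the `apply` reference is sketched below. The mixture-based simulation is an assumption chosen for illustration, and the arrays may need converting with `jnp.asarray` before being passed to the operator.

```python
import numpy as np

rng = np.random.default_rng(0)
n_spots, n_genes, n_cell_types = 50, 2000, 10

# Non-negative reference profile per cell type: (n_cell_types, n_genes).
cell_type_profiles = rng.gamma(2.0, 1.0, size=(n_cell_types, n_genes)).astype(np.float32)

# Ground-truth proportions on the simplex: (n_spots, n_cell_types).
true_proportions = rng.dirichlet(np.ones(n_cell_types), size=n_spots).astype(np.float32)

# Each spot is a proportion-weighted mixture of profiles: (n_spots, n_genes).
spot_expression = true_proportions @ cell_type_profiles

# 2-D spot positions: (n_spots, 2).
coords = rng.uniform(0.0, 10.0, size=(n_spots, 2)).astype(np.float32)

data = {
    "spot_expression": spot_expression,
    "reference_profiles": cell_type_profiles,
    "coordinates": coords,
}
```

Because `true_proportions` is known, the deconvolved proportions returned by the operator can be compared against it when testing a training loop.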

Hi-C Contact Analysis¤

from flax import nnx
from diffbio.operators.multiomics import HiCContactAnalysis, HiCContactAnalysisConfig

config = HiCContactAnalysisConfig(n_bins=1000, hidden_dim=64)
hic_analysis = HiCContactAnalysis(config, rngs=nnx.Rngs(42))

data = {
    "contact_matrix": hic_matrix,  # (n_bins, n_bins)
    "bin_features": bin_features,  # (n_bins, 16), matching config.bin_features
}
result, _, _ = hic_analysis.apply(data, {}, None)
compartments = result["compartment_scores"]
tad_boundaries = result["tad_boundary_scores"]
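For experimentation, synthetic Hi-C inputs with the documented shapes can be generated as sketched below. The distance-decay model and the random bin features are assumptions for illustration only, and `n_bins` here is smaller than the config default, so the config would need `n_bins=200` to match.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bins, n_bin_features = 200, 16

# Contact frequency decays with genomic distance between bins.
idx = np.arange(n_bins)
distance = np.abs(idx[:, None] - idx[None, :])
expected = 1.0 / (1.0 + distance)

# Symmetric multiplicative noise keeps the matrix symmetric.
noise = rng.uniform(0.9, 1.1, size=(n_bins, n_bins))
noise = 0.5 * (noise + noise.T)
hic_matrix = (expected * noise).astype(np.float32)

# Placeholder per-bin genomic features: (n_bins, bin_features).
bin_features = rng.normal(size=(n_bins, n_bin_features)).astype(np.float32)

data = {"contact_matrix": hic_matrix, "bin_features": bin_features}
```

Real contact maps are usually normalised before analysis. This toy matrix only reproduces the symmetry and the strong diagonal that such maps share.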

Spatial Gene Detection¤

from flax import nnx
from diffbio.operators.multiomics import (
    DifferentiableSpatialGeneDetector,
    SpatialGeneDetectorConfig,
    create_spatial_gene_detector,
)

# Using config
config = SpatialGeneDetectorConfig(
    n_genes=2000,
    lengthscale=1.0,
    variance=1.0,
    pvalue_threshold=0.05,
)
detector = DifferentiableSpatialGeneDetector(config, rngs=nnx.Rngs(42))

# Or using factory function
detector = create_spatial_gene_detector(
    n_genes=2000,
    lengthscale=1.0,
)

# Apply spatial gene detection
data = {
    "spatial_coords": coords,        # (n_spots, 2)
    "expression": expression,        # (n_spots, n_genes)
    "total_counts": total_counts,    # (n_spots,) optional
}
result, _, _ = detector.apply(data, {}, None)

# Get spatial gene results
fsv = result["fsv"]                      # Fraction of Spatial Variance
is_spatial = result["is_spatial"]        # Soft spatial indicator
smoothed = result["smoothed_expression"] # GP-smoothed expression