
Metabolomics Operators API

Differentiable operators for metabolomics analysis, including MS/MS spectral similarity.

DifferentiableSpectralSimilarity

diffbio.operators.metabolomics.spectral_similarity.DifferentiableSpectralSimilarity

DifferentiableSpectralSimilarity(
    config: SpectralSimilarityConfig, *, rngs: Rngs
)

Bases: OperatorModule

Siamese neural network for spectral similarity prediction.

This operator implements the MS2DeepScore architecture for predicting molecular structural similarity from tandem mass spectra. The network uses a shared encoder to generate spectral embeddings, then computes cosine similarity between pairs of embeddings.

The operator supports two modes of operation:

1. Single spectrum input: generates embeddings for spectra.
2. Paired spectra input: computes the similarity between spectrum pairs.

Architecture

Input (n_bins) -> shared encoder MLP -> Embedding (embedding_dim)
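To make the shared-encoder (Siamese) idea concrete, here is a minimal sketch in plain JAX that builds an MLP with the default layer sizes and applies the same parameters to both sides of a pair. It is not the operator's Artifex backbone: the helper names init_mlp and mlp_encode, the ReLU activations, and the He-style initialization are assumptions, and BatchNorm and Dropout are left out for brevity.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, n_bins=1000, hidden_dims=(512, 256), embedding_dim=200):
    """Hypothetical helper: weights for an MLP n_bins -> hidden_dims -> embedding_dim."""
    sizes = [n_bins, *hidden_dims, embedding_dim]
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) * jnp.sqrt(2.0 / d_in)  # He-style init (assumed)
        params.append((w, jnp.zeros(d_out)))
    return params


def mlp_encode(params, spectra):
    """Map (n, n_bins) binned spectra to (n, embedding_dim) embeddings."""
    x = spectra
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)  # hidden layers with ReLU (assumed activation)
    w, b = params[-1]
    return x @ w + b  # final linear projection to the embedding space


params = init_mlp(jax.random.PRNGKey(0))
spectra_a = jax.random.uniform(jax.random.PRNGKey(1), (5, 1000))
spectra_b = jax.random.uniform(jax.random.PRNGKey(2), (5, 1000))

# Siamese: the *same* params encode both sides of every pair.
emb_a = mlp_encode(params, spectra_a)
emb_b = mlp_encode(params, spectra_b)
print(emb_a.shape, emb_b.shape)  # (5, 200) (5, 200)
```

Sharing one set of weights is what makes the embeddings of the two spectra directly comparable, so a simple cosine similarity on top is meaningful.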

Attributes:

config (SpectralSimilarityConfig): Configuration holding the network hyperparameters.
backbone: Shared Artifex encoder backbone that produces the spectral embeddings.

Example
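import jax
from flax import nnx
from diffbio.operators.metabolomics import (
    DifferentiableSpectralSimilarity,
    SpectralSimilarityConfig,
)

config = SpectralSimilarityConfig(n_bins=1000, embedding_dim=200)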
operator = DifferentiableSpectralSimilarity(config, rngs=nnx.Rngs(42))
# Get embeddings for spectra
spectra = jax.random.uniform(jax.random.PRNGKey(0), (10, 1000))
result, _, _ = operator.apply({"spectra": spectra}, {}, None)
embeddings = result["embeddings"]  # (10, 200)
# Compute pairwise similarity
spectra_a = jax.random.uniform(jax.random.PRNGKey(0), (5, 1000))
spectra_b = jax.random.uniform(jax.random.PRNGKey(1), (5, 1000))
result, _, _ = operator.apply(
    {"spectra_a": spectra_a, "spectra_b": spectra_b}, {}, None
)
similarity = result["similarity_scores"]  # (5,) in [-1, 1]

Parameters:

config (SpectralSimilarityConfig, required): Configuration with the network hyperparameters.
rngs (Rngs, required): Flax NNX random number generators.

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Apply the spectral similarity operator.

The operator supports two input modes, selected by the keys present in data:

1. Single spectra mode (embedding generation)
   Input: {"spectra": (n, n_bins)}
   Output: {"embeddings": (n, embedding_dim)}
2. Paired spectra mode (similarity computation)
   Input: {"spectra_a": (n, n_bins), "spectra_b": (n, n_bins)}
   Output: {"similarity_scores": (n,), "embeddings_a": ..., "embeddings_b": ...}

In both modes the input keys are passed through unchanged alongside the new outputs.
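The key-based mode selection described above can be sketched as a small dispatch function. This is a hypothetical illustration, not diffbio's implementation: apply_sketch, the order in which the keys are checked, and the KeyError for unrecognized inputs are all assumptions, and the caller supplies the encoder and similarity functions.

```python
import jax.numpy as jnp


def apply_sketch(data, encode, similarity):
    """Hypothetical dispatch on input keys, mirroring the two documented modes.

    encode maps (n, n_bins) -> (n, embedding_dim); similarity maps two
    (n, embedding_dim) arrays to (n,) scores.
    """
    out = dict(data)  # input keys are passed through to the output
    if "spectra_a" in data and "spectra_b" in data:  # paired spectra mode
        emb_a = encode(data["spectra_a"])
        emb_b = encode(data["spectra_b"])
        out["embeddings_a"] = emb_a
        out["embeddings_b"] = emb_b
        out["similarity_scores"] = similarity(emb_a, emb_b)
    elif "spectra" in data:  # single spectra mode
        out["embeddings"] = encode(data["spectra"])
    else:
        raise KeyError("expected 'spectra' or both 'spectra_a' and 'spectra_b'")
    return out


def toy_encode(x):
    """Stand-in encoder: keep the first 3 bins as the 'embedding'."""
    return x[:, :3]


def toy_similarity(a, b):
    """Stand-in similarity: row-wise dot product."""
    return jnp.sum(a * b, axis=-1)


single = apply_sketch({"spectra": jnp.ones((4, 10))}, toy_encode, toy_similarity)
paired = apply_sketch(
    {"spectra_a": jnp.ones((2, 10)), "spectra_b": jnp.ones((2, 10))},
    toy_encode,
    toy_similarity,
)
print(sorted(single))  # ['embeddings', 'spectra']
print(paired["similarity_scores"].tolist())  # [3.0, 3.0]
```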

Parameters:

data (dict[str, Any], required): Input data dictionary with the spectra.
state (dict[str, Any], required): Per-element state (passed through).
metadata (dict[str, Any] | None, required): Optional metadata (passed through).
random_params (Any, default None): Random parameters (unused).
stats (dict[str, Any] | None, default None): Optional statistics (unused).

Returns:

tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]: Tuple of (output_data, state, metadata).

encode

encode(spectra: ndarray) -> ndarray

Encode binned spectra into embeddings.

BatchNorm and Dropout respect the model's train/eval mode:

- Call model.train() before training to enable dropout and update the batch statistics.
- Call model.eval() before inference to disable dropout and use the running statistics.

Parameters:

spectra (ndarray, required): Binned spectra with shape (n_spectra, n_bins).

Returns:

ndarray: Embeddings with shape (n_spectra, embedding_dim).

cosine_similarity

cosine_similarity(
    embeddings_a: ndarray, embeddings_b: ndarray
) -> ndarray

Compute the row-wise cosine similarity between paired embeddings: row i of embeddings_a is compared with row i of embeddings_b.

Parameters:

embeddings_a (ndarray, required): First set of embeddings with shape (n, embedding_dim).
embeddings_b (ndarray, required): Second set of embeddings with shape (n, embedding_dim).

Returns:

ndarray: Cosine similarity scores with shape (n,), in [-1, 1].
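A minimal JAX sketch of row-wise cosine similarity is shown below. The name cosine_similarity_sketch and the eps floor on the denominator are assumptions for illustration; the operator's own method may handle zero-norm embeddings differently.

```python
import jax
import jax.numpy as jnp


def cosine_similarity_sketch(embeddings_a, embeddings_b, eps=1e-8):
    """Row-wise cosine similarity of two (n, d) arrays; returns shape (n,)."""
    dot = jnp.sum(embeddings_a * embeddings_b, axis=-1)
    norms = jnp.linalg.norm(embeddings_a, axis=-1) * jnp.linalg.norm(embeddings_b, axis=-1)
    # eps avoids division by zero in the forward pass for all-zero rows.
    return dot / jnp.maximum(norms, eps)


a = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 0.0]])
b = jnp.array([[2.0, 0.0], [-1.0, -1.0], [0.0, 3.0]])
scores = cosine_similarity_sketch(a, b)  # approximately [1, -1, 0]

# Differentiable end to end: gradient of the summed scores w.r.t. embeddings_a.
grad_a = jax.grad(lambda x: cosine_similarity_sketch(x, b).sum())(a)
print(grad_a.shape)  # (3, 2)
```

Because the score depends only on the angle between embeddings, scaling either input (as in the first row, where b is twice a) leaves the similarity unchanged.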

SpectralSimilarityConfig

diffbio.operators.metabolomics.spectral_similarity.SpectralSimilarityConfig (dataclass)

SpectralSimilarityConfig(
    n_bins: int = 1000,
    embedding_dim: int = 200,
    hidden_dims: tuple[int, ...] = (512, 256),
    dropout_rate: float = 0.2,
    min_mz: float = 0.0,
    max_mz: float = 1000.0,
    use_batch_norm: bool = True,
)

Bases: OperatorConfig

Configuration for DifferentiableSpectralSimilarity.

Attributes:

n_bins (int): Number of m/z bins used to discretize each spectrum. Default 1000, which with the default min_mz and max_mz covers 0-1000 m/z at 1 m/z resolution. The original MS2DeepScore uses 10000 bins at 0.1 m/z resolution.
embedding_dim (int): Dimension of the spectral embeddings. Default 200.
hidden_dims (tuple[int, ...]): Hidden layer dimensions of the encoder. Default (512, 256). The original MS2DeepScore uses (500, 500).
dropout_rate (float): Dropout rate for regularization. Default 0.2.
min_mz (float): Minimum m/z value for binning. Default 0.0.
max_mz (float): Maximum m/z value for binning. Default 1000.0.
use_batch_norm (bool): Whether to use batch normalization. Default True.
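The m/z resolution follows directly from n_bins, min_mz and max_mz. The short calculation below assumes that the n_bins fixed-width bins evenly tile the interval [min_mz, max_mz]; bin_width is a hypothetical helper, not part of diffbio.

```python
def bin_width(min_mz: float, max_mz: float, n_bins: int) -> float:
    """Width of each fixed-width m/z bin, in m/z units."""
    return (max_mz - min_mz) / n_bins


print(bin_width(0.0, 1000.0, 1000))   # 1.0  (SpectralSimilarityConfig defaults)
print(bin_width(0.0, 1000.0, 10000))  # 0.1  (10000 bins over the same 0-1000 m/z range)
```

Raising n_bins tenfold gives finer peak separation but also a tenfold wider first encoder layer, since the input dimension equals n_bins.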

bin_spectrum

diffbio.operators.metabolomics.spectral_similarity.bin_spectrum

bin_spectrum(
    mz_values: ndarray,
    intensities: ndarray,
    n_bins: int = 1000,
    min_mz: float = 0.0,
    max_mz: float = 1000.0,
    normalize: bool = True,
) -> ndarray

Bin a mass spectrum into fixed-width m/z bins.

This function discretizes a mass spectrum, given as paired arrays of m/z values and intensities, into a fixed-size vector of length n_bins suitable as neural network input.
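One way such fixed-width binning can be written in JAX is sketched below. This is a hypothetical stand-in, not diffbio's bin_spectrum: summing the intensities of peaks that share a bin, dropping peaks outside [min_mz, max_mz), and normalizing by the largest bin are all assumptions.

```python
import jax
import jax.numpy as jnp


def bin_spectrum_sketch(mz_values, intensities, n_bins=1000, min_mz=0.0,
                        max_mz=1000.0, normalize=True):
    """Hypothetical fixed-width binning: sum peak intensities per m/z bin."""
    width = (max_mz - min_mz) / n_bins
    idx = jnp.floor((mz_values - min_mz) / width).astype(jnp.int32)
    in_range = (mz_values >= min_mz) & (mz_values < max_mz)
    # Out-of-range peaks contribute zero; clipping keeps their indices valid.
    weights = jnp.where(in_range, intensities, 0.0)
    binned = jnp.zeros(n_bins).at[jnp.clip(idx, 0, n_bins - 1)].add(weights)
    if normalize:
        binned = binned / jnp.maximum(binned.max(), 1e-12)  # scale so the max bin is 1.0
    return binned


mz = jnp.array([100.0, 150.5, 200.0, 350.2, 500.0])
intensities = jnp.array([0.3, 1.0, 0.5, 0.8, 0.2])
binned = bin_spectrum_sketch(mz, intensities)
print(binned.shape)  # (1000,)
print(float(binned[150]))  # 1.0 (the largest peak, at m/z 150.5)

# Gradients reach the intensities; floor() gives the m/z positions no gradient.
grads = jax.grad(lambda i: bin_spectrum_sketch(mz, i, normalize=False)[150])(intensities)
```

In this sketch the output is differentiable with respect to the intensities but not the peak positions, which is worth keeping in mind if gradients are expected to flow back to raw spectra.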

Parameters:

mz_values (ndarray, required): Array of m/z values with shape (n_peaks,).
intensities (ndarray, required): Array of intensity values with shape (n_peaks,).
n_bins (int, default 1000): Number of bins to use.
min_mz (float, default 0.0): Minimum m/z value for binning.
max_mz (float, default 1000.0): Maximum m/z value for binning.
normalize (bool, default True): Whether to scale intensities so that the maximum is 1.0.

Returns:

ndarray: Binned spectrum with shape (n_bins,).

Example

import jax.numpy as jnp
from diffbio.operators.metabolomics import bin_spectrum

mz = jnp.array([100.0, 200.0, 300.0])
intensity = jnp.array([0.5, 1.0, 0.3])
binned = bin_spectrum(mz, intensity, n_bins=100)
binned.shape  # (100,)

create_spectral_similarity

diffbio.operators.metabolomics.spectral_similarity.create_spectral_similarity

create_spectral_similarity(
    n_bins: int = 1000,
    embedding_dim: int = 200,
    hidden_dims: tuple[int, ...] = (512, 256),
    dropout_rate: float = 0.2,
    seed: int = 42,
) -> DifferentiableSpectralSimilarity

Factory function to create a spectral similarity operator.

Parameters:

n_bins (int, default 1000): Number of m/z bins.
embedding_dim (int, default 200): Embedding dimension.
hidden_dims (tuple[int, ...], default (512, 256)): Hidden layer dimensions.
dropout_rate (float, default 0.2): Dropout rate.
seed (int, default 42): Random seed.

Returns:

DifferentiableSpectralSimilarity: Configured DifferentiableSpectralSimilarity operator.

Example
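import jax
from diffbio.operators.metabolomics import create_spectral_similarity

operator = create_spectral_similarity(n_bins=500, embedding_dim=128)
spectra = jax.random.uniform(jax.random.PRNGKey(0), (10, 500))
result, _, _ = operator.apply({"spectra": spectra}, {}, None)
embeddings = result["embeddings"]  # (10, 128)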

Usage Examples

Basic Spectral Similarity

from flax import nnx
import jax
from diffbio.operators.metabolomics import (
    DifferentiableSpectralSimilarity,
    SpectralSimilarityConfig,
    create_spectral_similarity,
    bin_spectrum,
)

# Using config
config = SpectralSimilarityConfig(
    n_bins=1000,
    embedding_dim=200,
    hidden_dims=(512, 256),
)
operator = DifferentiableSpectralSimilarity(config, rngs=nnx.Rngs(42))

# Or using factory function
operator = create_spectral_similarity(n_bins=1000)

# Compute embeddings
spectra = jax.random.uniform(jax.random.PRNGKey(0), (100, 1000))
result, _, _ = operator.apply({"spectra": spectra}, {}, None)
embeddings = result["embeddings"]  # (100, 200)

Pairwise Similarity

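import jax
from diffbio.operators.metabolomics import create_spectral_similarity

operator = create_spectral_similarity(n_bins=1000)

# Compare pairs of spectra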
spectra_a = jax.random.uniform(jax.random.PRNGKey(0), (50, 1000))
spectra_b = jax.random.uniform(jax.random.PRNGKey(1), (50, 1000))

data = {"spectra_a": spectra_a, "spectra_b": spectra_b}
result, _, _ = operator.apply(data, {}, None)

similarity = result["similarity_scores"]  # (50,) in [-1, 1]

Spectrum Binning

import jax.numpy as jnp
from diffbio.operators.metabolomics import bin_spectrum

# Raw mass spectrum
mz_values = jnp.array([100.0, 150.5, 200.0, 350.2, 500.0])
intensities = jnp.array([0.3, 1.0, 0.5, 0.8, 0.2])

# Discretize into bins
binned = bin_spectrum(
    mz_values,
    intensities,
    n_bins=1000,
    min_mz=0.0,
    max_mz=1000.0,
    normalize=True,
)
# binned.shape == (1000,)

Input Specifications

Single Spectra Mode

spectra (n_spectra, n_bins): Binned mass spectra.

Paired Spectra Mode

spectra_a (n_pairs, n_bins): First set of binned spectra.
spectra_b (n_pairs, n_bins): Second set of binned spectra.

Output Specifications

Single Spectra Mode

spectra (n_spectra, n_bins): Original input spectra.
embeddings (n_spectra, embedding_dim): Spectral embeddings.

Paired Spectra Mode

spectra_a (n_pairs, n_bins): Original first spectra.
spectra_b (n_pairs, n_bins): Original second spectra.
embeddings_a (n_pairs, embedding_dim): Embeddings of the first set.
embeddings_b (n_pairs, embedding_dim): Embeddings of the second set.
similarity_scores (n_pairs,): Cosine similarity scores.
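When wiring the operator into a larger pipeline, it can help to assert this output contract in tests. The helper check_output_shapes below is hypothetical (not part of diffbio) and simply encodes the shapes listed in the two output specifications; the stand-in arrays take the place of a real operator.apply result.

```python
import jax.numpy as jnp


def check_output_shapes(result, n, n_bins, embedding_dim):
    """Hypothetical helper: compare an apply() result with the output specifications."""
    if "similarity_scores" in result:  # paired spectra mode
        expected = {
            "spectra_a": (n, n_bins),
            "spectra_b": (n, n_bins),
            "embeddings_a": (n, embedding_dim),
            "embeddings_b": (n, embedding_dim),
            "similarity_scores": (n,),
        }
    else:  # single spectra mode
        expected = {"spectra": (n, n_bins), "embeddings": (n, embedding_dim)}
    for key, shape in expected.items():
        if key not in result:
            raise KeyError(f"missing output key: {key}")
        if result[key].shape != shape:
            raise ValueError(f"{key}: expected {shape}, got {result[key].shape}")
    return True


# Stand-in result with single-mode shapes (n=10, n_bins=1000, embedding_dim=200).
fake_single = {"spectra": jnp.zeros((10, 1000)), "embeddings": jnp.zeros((10, 200))}
print(check_output_shapes(fake_single, n=10, n_bins=1000, embedding_dim=200))  # True
```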