Foundation Model Operators API¤

Differentiable transformer-based operators for DNA/RNA sequence embedding and single-cell foundation-model workflows.

TransformerSequenceEncoder¤

diffbio.operators.foundation_models.transformer_encoder.TransformerSequenceEncoder ¤

TransformerSequenceEncoder(
    config: TransformerSequenceEncoderConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: FoundationEmbeddingMixin, SequenceOperator

Transformer-based encoder for DNA/RNA sequences.

This operator implements a BERT-style transformer encoder that converts nucleotide sequences into dense embeddings. The architecture follows DNABERT and RNA-FM patterns.

Uses artifex's TransformerEncoder for the core transformer layers, following the DRY principle.

Supports two input embedding modes:

  • "linear" (default): Projects one-hot encoded input (seq_len, alphabet_size) via nnx.Linear. This is the standard mode for continuous one-hot input.
  • "token_embedding": Embeds integer token IDs (seq_len,) via nnx.Embed. Useful for gene-token foundation models and tokenized input.
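The two input layouts can be sketched as follows. This is a minimal illustration, not part of diffbio: the `NUCLEOTIDE_TO_ID` mapping and both helper functions are hypothetical, and the A/C/G/T index order is an assumption that should be checked against the preprocessing your model expects.

```python
import jax
import jax.numpy as jnp

# Hypothetical alphabet order, chosen for illustration only.
NUCLEOTIDE_TO_ID = {"A": 0, "C": 1, "G": 2, "T": 3}


def to_token_ids(seq: str) -> jax.Array:
    """Integer token IDs of shape (seq_len,), the layout for "token_embedding"."""
    return jnp.array([NUCLEOTIDE_TO_ID[base] for base in seq], dtype=jnp.int32)


def to_one_hot(seq: str) -> jax.Array:
    """One-hot matrix of shape (seq_len, alphabet_size), the layout for "linear"."""
    return jax.nn.one_hot(to_token_ids(seq), num_classes=len(NUCLEOTIDE_TO_ID))


token_ids = to_token_ids("ACGTAC")  # shape (6,)
one_hot = to_one_hot("ACGTAC")      # shape (6, 4)
```

Either array would then be passed under the "sequence" key, depending on the configured input_embedding_type.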

The encoder produces:

  • Global sequence embedding via mean pooling or CLS token
  • Per-position embeddings for fine-grained analysis
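The difference between the two pooling strategies can be sketched in a few lines. This is an illustration under stated assumptions, not the library's implementation: it assumes padded positions are excluded from the mean and that the CLS slot is at index 0; `mean_pool` and `cls_pool` are hypothetical helper names.

```python
import jax.numpy as jnp


def mean_pool(hidden, mask):
    """Average per-position states over valid positions only.

    hidden: (seq_len, hidden_dim); mask: (seq_len,) with 1=valid, 0=padded.
    """
    weights = mask[:, None].astype(hidden.dtype)
    return (hidden * weights).sum(axis=0) / jnp.maximum(weights.sum(), 1.0)


def cls_pool(hidden):
    """Use the first position as the summary vector (assumes CLS at index 0)."""
    return hidden[0]


hidden = jnp.arange(12.0).reshape(4, 3)  # 4 positions, hidden_dim=3
mask = jnp.array([1, 1, 0, 0])           # last two positions are padding
pooled_mean = mean_pool(hidden, mask)    # mean of rows 0 and 1
pooled_cls = cls_pool(hidden)            # row 0
```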

Parameters:

Name Type Description Default
config TransformerSequenceEncoderConfig

TransformerSequenceEncoderConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example
config = TransformerSequenceEncoderConfig(hidden_dim=256)
encoder = TransformerSequenceEncoder(config, rngs=nnx.Rngs(42))
data = {"sequence": one_hot_sequence}
result, state, meta = encoder.apply(data, {}, None)
embeddings = result["embeddings"]

Parameters:

Name Type Description Default
config TransformerSequenceEncoderConfig

Encoder configuration.

required
rngs Rngs | None

Random number generators for initialization.

None
name str | None

Optional operator name.

None

apply ¤

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply transformer encoding to sequence data.

This method encodes DNA/RNA sequences into dense embeddings using a transformer encoder architecture.

Input shape depends on input_embedding_type:

  • "linear": one-hot (seq_len, alphabet_size) or (batch, seq_len, alphabet_size)
  • "token_embedding": integer token IDs (seq_len,) or (batch, seq_len)

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

  • "sequence": Encoded sequence(s) (see above for shapes)
  • "attention_mask": Optional mask (seq_len,) or (batch, seq_len)

required
state PyTree

Element state (passed through unchanged)

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged)

required
random_params Any

Not used

None
stats dict[str, Any] | None

Not used

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

  • transformed_data contains:
      - All original keys from data
      - "embeddings": Global sequence embedding
      - "token_embeddings": Per-position hidden states
      - "foundation_model": Canonical artifact metadata
  • state is passed through unchanged
  • metadata is passed through unchanged

get_positional_encoding ¤

get_positional_encoding(
    seq_len: int,
) -> Float[Array, "seq_len hidden_dim"]

Generate sinusoidal positional encoding.

Provided for compatibility only: the transformer applies its own internal positional encoding, so the result of this method is not needed for a normal forward pass.

Parameters:

Name Type Description Default
seq_len int

Sequence length.

required

Returns:

Type Description
Float[Array, 'seq_len hidden_dim']

Positional encoding matrix.
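For reference, the standard sinusoidal scheme from the original Transformer paper can be sketched as below. Whether get_positional_encoding uses exactly this frequency base and sin/cos interleaving is an assumption; `sinusoidal_positional_encoding` is a hypothetical standalone function, not the method itself.

```python
import jax.numpy as jnp


def sinusoidal_positional_encoding(seq_len: int, hidden_dim: int):
    """Return a (seq_len, hidden_dim) matrix; assumes hidden_dim is even."""
    positions = jnp.arange(seq_len)[:, None]            # (seq_len, 1)
    even_dims = jnp.arange(0, hidden_dim, 2)[None, :]   # (1, hidden_dim / 2)
    angles = positions / jnp.power(10000.0, even_dims / hidden_dim)
    encoding = jnp.zeros((seq_len, hidden_dim))
    encoding = encoding.at[:, 0::2].set(jnp.sin(angles))  # even columns
    encoding = encoding.at[:, 1::2].set(jnp.cos(angles))  # odd columns
    return encoding


pe = sinusoidal_positional_encoding(seq_len=100, hidden_dim=256)
```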

TransformerSequenceEncoderConfig¤

diffbio.operators.foundation_models.transformer_encoder.TransformerSequenceEncoderConfig dataclass ¤

TransformerSequenceEncoderConfig(
    adapter_mode: AdapterMode = NATIVE_TRAINABLE,
    artifact_id: str = "diffbio.transformer_sequence_encoder",
    preprocessing_version: str = "one_hot_v1",
    dropout_rate: float = 0.1,
    pooling: Literal["mean", "cls"] = "mean",
    alphabet_size: int = 4,
    input_embedding_type: Literal[
        "linear", "token_embedding"
    ] = "linear",
    vocab_size: int | None = None,
    hidden_dim: int = 256,
    num_layers: int = 4,
    num_heads: int = 4,
    intermediate_dim: int = 1024,
    max_length: int = 512,
)

Bases: _TransformerArchitectureConfig, _TransformerInputConfig, _TransformerOutputConfig, TransformerEncoderShapeValidationMixin, FoundationEmbeddingOperatorConfig

Configuration for TransformerSequenceEncoder.

adapter_mode class-attribute instance-attribute ¤

adapter_mode: AdapterMode = NATIVE_TRAINABLE

artifact_id class-attribute instance-attribute ¤

artifact_id: str = 'diffbio.transformer_sequence_encoder'

preprocessing_version class-attribute instance-attribute ¤

preprocessing_version: str = 'one_hot_v1'

hidden_dim class-attribute instance-attribute ¤

hidden_dim: int = 256

num_layers class-attribute instance-attribute ¤

num_layers: int = 4

num_heads class-attribute instance-attribute ¤

num_heads: int = 4

intermediate_dim class-attribute instance-attribute ¤

intermediate_dim: int = 1024

max_length class-attribute instance-attribute ¤

max_length: int = 512

dropout_rate class-attribute instance-attribute ¤

dropout_rate: float = 0.1

pooling class-attribute instance-attribute ¤

pooling: Literal['mean', 'cls'] = 'mean'

alphabet_size class-attribute instance-attribute ¤

alphabet_size: int = 4

input_embedding_type class-attribute instance-attribute ¤

input_embedding_type: Literal[
    "linear", "token_embedding"
] = "linear"

vocab_size class-attribute instance-attribute ¤

vocab_size: int | None = None

Factory Functions¤

create_dna_encoder¤

diffbio.operators.foundation_models.transformer_encoder.create_dna_encoder ¤

create_dna_encoder(
    hidden_dim: int = 256,
    num_layers: int = 4,
    num_heads: int = 4,
    intermediate_dim: int | None = None,
    max_length: int = 512,
    dropout_rate: float = 0.1,
    pooling: Literal["mean", "cls"] = "mean",
    *,
    rngs: Rngs | None = None,
) -> TransformerSequenceEncoder

Create a transformer encoder for DNA sequences.

Factory function for creating a DNA sequence encoder with sensible defaults for DNA processing.

Parameters:

Name Type Description Default
hidden_dim int

Dimension of hidden states.

256
num_layers int

Number of transformer layers.

4
num_heads int

Number of attention heads.

4
intermediate_dim int | None

FFN intermediate dimension (default: 4 * hidden_dim).

None
max_length int

Maximum sequence length.

512
dropout_rate float

Dropout rate.

0.1
pooling Literal['mean', 'cls']

Pooling strategy.

'mean'
rngs Rngs | None

Random number generators.

None

Returns:

Type Description
TransformerSequenceEncoder

Configured TransformerSequenceEncoder for DNA.

Example
encoder = create_dna_encoder(hidden_dim=256, num_layers=6)
data = {"sequence": dna_one_hot}
result, _, _ = encoder.apply(data, {}, None)
embeddings = result["embeddings"]

create_rna_encoder¤

diffbio.operators.foundation_models.transformer_encoder.create_rna_encoder ¤

create_rna_encoder(
    hidden_dim: int = 256,
    num_layers: int = 4,
    num_heads: int = 4,
    intermediate_dim: int | None = None,
    max_length: int = 512,
    dropout_rate: float = 0.1,
    pooling: Literal["mean", "cls"] = "mean",
    *,
    rngs: Rngs | None = None,
) -> TransformerSequenceEncoder

Create a transformer encoder for RNA sequences.

Factory function for creating an RNA sequence encoder with sensible defaults for RNA processing.

Parameters:

Name Type Description Default
hidden_dim int

Dimension of hidden states.

256
num_layers int

Number of transformer layers.

4
num_heads int

Number of attention heads.

4
intermediate_dim int | None

FFN intermediate dimension (default: 4 * hidden_dim).

None
max_length int

Maximum sequence length.

512
dropout_rate float

Dropout rate.

0.1
pooling Literal['mean', 'cls']

Pooling strategy.

'mean'
rngs Rngs | None

Random number generators.

None

Returns:

Type Description
TransformerSequenceEncoder

Configured TransformerSequenceEncoder for RNA.

Example
encoder = create_rna_encoder(hidden_dim=640, num_layers=12)
data = {"sequence": rna_one_hot}
result, _, _ = encoder.apply(data, {}, None)

DifferentiableFoundationModel¤

diffbio.operators.foundation_models.foundation_model.DifferentiableFoundationModel ¤

DifferentiableFoundationModel(
    config: FoundationModelConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: FoundationEmbeddingMixin, MaskedGeneTransformerOperatorMixin, OperatorModule

Differentiable single-cell foundation model operator.

Implements a masked gene expression prediction model inspired by Geneformer (rank-value tokenization) and scGPT (masked expression prediction with gene + value embeddings).

Algorithm:

  1. Tokenize: Rank genes by expression per cell via soft sort (Geneformer-style, used for gene embedding ordering context).
  2. Embed gene IDs via TransformerSequenceEncoder with input_embedding_type="token_embedding".
  3. Add expression value projection: scalar expression values are projected to hidden_dim and added to gene embeddings (scGPT-style).
  4. Random mask: mask_ratio fraction of genes have their expression embeddings replaced with a learned mask token.
  5. Transformer encoder: contextualizes gene representations.
  6. Predict: linear output head predicts masked gene expression.
  7. Cell embedding: mean pooling of non-masked gene representations.
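Steps 4 and 7 can be sketched for a single cell as follows. This is an illustration under assumptions, not the operator's code: it draws an independent Bernoulli mask per gene (so the masked fraction is only approximately mask_ratio), uses a zero vector in place of the learned mask token, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def mask_expression_embeddings(key, expr_emb, mask_token, mask_ratio):
    """Replace a random subset of per-gene expression embeddings with a mask token.

    expr_emb: (n_genes, hidden_dim); mask_token: (hidden_dim,).
    Returns the masked embeddings and a boolean vector (True = masked).
    """
    is_masked = jax.random.bernoulli(key, p=mask_ratio, shape=expr_emb.shape[:1])
    masked = jnp.where(is_masked[:, None], mask_token[None, :], expr_emb)
    return masked, is_masked


def pool_unmasked(gene_repr, is_masked):
    """Mean over the genes that were not masked."""
    keep = (~is_masked).astype(gene_repr.dtype)[:, None]
    return (gene_repr * keep).sum(axis=0) / jnp.maximum(keep.sum(), 1.0)


key = jax.random.key(0)
expr_emb = jnp.ones((2000, 128))
mask_token = jnp.zeros((128,))  # stands in for a learned parameter
masked_emb, is_masked = mask_expression_embeddings(key, expr_emb, mask_token, 0.15)
cell_embedding = pool_unmasked(masked_emb, is_masked)  # shape (128,)
```

In the real operator the pooled representations would come out of the transformer encoder rather than directly from the masked inputs; the sketch skips that step to keep the masking and pooling logic visible.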

Parameters:

Name Type Description Default
config FoundationModelConfig

FoundationModelConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example

config = FoundationModelConfig(n_genes=2000, hidden_dim=128)
model = DifferentiableFoundationModel(
    config, rngs=nnx.Rngs(params=0, sample=1, dropout=2))
rp = model.generate_random_params(
    jax.random.key(0), {"counts": (100, 2000)})
data = {"counts": counts, "gene_ids": jnp.arange(2000)}
result, state, meta = model.apply(data, {}, None, random_params=rp)

Parameters:

Name Type Description Default
config FoundationModelConfig

Foundation model configuration.

required
rngs Rngs | None

Random number generators for parameter initialization.

None
name str | None

Optional operator name.

None

apply ¤

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply foundation model to single-cell count data.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

  • "counts": Gene expression matrix (n_cells, n_genes)
  • "gene_ids": Integer gene IDs (n_genes,)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

JAX random key for mask generation.

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

  • transformed_data contains:
      - All original keys from data
      - "embeddings": Cell embeddings (n_cells, hidden_dim)
      - "token_embeddings": Contextual gene embeddings (n_cells, n_genes, hidden_dim)
      - "predicted_expression": Predicted expression (n_cells, n_genes)
      - "foundation_model": Canonical artifact metadata
  • state is passed through unchanged
  • metadata is passed through unchanged

FoundationModelConfig¤

diffbio.operators.foundation_models.foundation_model.FoundationModelConfig dataclass ¤

FoundationModelConfig(
    n_genes: int = 2000,
    hidden_dim: int = 128,
    num_layers: int = 2,
    num_heads: int = 4,
    mask_ratio: float = 0.15,
    dropout_rate: float = 0.1,
    adapter_mode: AdapterMode = NATIVE_TRAINABLE,
    artifact_id: str = "diffbio.differentiable_foundation_model",
    preprocessing_version: str = "counts_gene_ids_v1",
)

Bases: FoundationEmbeddingOperatorConfig, MaskedGeneTransformerConfigBase

Configuration for DifferentiableFoundationModel.

Attributes:

Name Type Description
adapter_mode AdapterMode

Integration mode for the model artifact.

artifact_id str

Identifier for the model artifact/version.

preprocessing_version str

Version tag for count/gene-ID preprocessing.

n_genes class-attribute instance-attribute ¤

n_genes: int = 2000

hidden_dim class-attribute instance-attribute ¤

hidden_dim: int = 128

num_layers class-attribute instance-attribute ¤

num_layers: int = 2

num_heads class-attribute instance-attribute ¤

num_heads: int = 4

mask_ratio class-attribute instance-attribute ¤

mask_ratio: float = 0.15

dropout_rate class-attribute instance-attribute ¤

dropout_rate: float = 0.1

adapter_mode class-attribute instance-attribute ¤

adapter_mode: AdapterMode = NATIVE_TRAINABLE

__post_init__ ¤

__post_init__() -> None

Validate foundation-model metadata fields.

GeneTokenizer¤

diffbio.operators.foundation_models.foundation_model.GeneTokenizer ¤

GeneTokenizer(n_genes: int, *, rngs: Rngs)

Bases: Module

Geneformer-style rank-value gene tokenizer.

Converts gene expression vectors into rank-ordered representations using a differentiable soft sort approximation. For each cell, genes are ranked by expression value in descending order. The output is a soft permutation matrix of shape (n_genes, n_genes) where row i is a soft one-hot indicating which gene occupies rank i.

The key insight from Geneformer: token IDs are gene indices sorted by expression magnitude. We approximate the discrete argsort with a temperature-controlled soft permutation to maintain differentiability.

Parameters:

Name Type Description Default
n_genes int

Number of genes in the vocabulary.

required
rngs Rngs

Flax NNX random number generators (unused, kept for NNX API).

required

__call__ ¤

__call__(
    expression: Float[Array, n_genes],
    temperature: float = 1.0,
) -> Float[Array, "n_genes n_genes"]

Compute soft permutation matrix from expression values.

Genes are ranked in descending order of expression. The returned matrix P has shape (n_genes, n_genes) where P[i, j] approximates the probability that gene j occupies rank i.

At low temperature this approaches a hard permutation matrix (the true argsort).

Parameters:

Name Type Description Default
expression Float[Array, n_genes]

Gene expression values for one cell, shape (n_genes,).

required
temperature float

Softmax temperature (lower is sharper).

1.0

Returns:

Type Description
Float[Array, 'n_genes n_genes']

Soft permutation matrix of shape (n_genes, n_genes).
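One well-known way to build such a matrix is the SoftSort relaxation, sketched below. It is shown only to make the row/column semantics concrete; whether GeneTokenizer uses this exact formulation is an assumption, and `soft_permutation` is a hypothetical standalone function.

```python
import jax
import jax.numpy as jnp


def soft_permutation(expression, temperature=1.0):
    """SoftSort-style relaxation of a descending argsort.

    Row i is a softmax over genes, peaked at the gene whose value is
    closest to the i-th largest expression value.
    """
    sorted_desc = jnp.sort(expression)[::-1]                        # (n_genes,)
    distance = jnp.abs(sorted_desc[:, None] - expression[None, :])  # (n_genes, n_genes)
    return jax.nn.softmax(-distance / temperature, axis=-1)


expression = jnp.array([0.5, 3.0, 1.0, 2.0])
P = soft_permutation(expression, temperature=0.01)
hard_ranks = jnp.argmax(P, axis=-1)  # gene index occupying each rank
```

At this low temperature the argmax of each row matches `jnp.argsort(-expression)`; raising the temperature spreads each row over neighbouring genes, which is what keeps gradients informative.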

Usage Examples¤

Basic Usage¤

from diffbio.operators.foundation_models import create_dna_encoder
import jax
import jax.numpy as jnp

# Create encoder
encoder = create_dna_encoder(hidden_dim=256, num_layers=4)

# Prepare one-hot encoded sequence
sequence = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (100,), 0, 4),
    num_classes=4,
)

# Apply
result, _, _ = encoder.apply({"sequence": sequence}, {}, None)
embedding = result["embeddings"]  # (256,)

Full Configuration¤

from diffbio.operators.foundation_models import (
    TransformerSequenceEncoder,
    TransformerSequenceEncoderConfig,
)
from flax import nnx

config = TransformerSequenceEncoderConfig(
    hidden_dim=640,
    num_layers=12,
    num_heads=20,
    intermediate_dim=5120,
    max_length=1024,
    alphabet_size=4,
    dropout_rate=0.1,
    pooling="cls",
)

encoder = TransformerSequenceEncoder(config, rngs=nnx.Rngs(42))

Batched Processing¤

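import jax

from diffbio.operators.foundation_models import create_dna_encoder

# Encoder with hidden_dim=256, as in Basic Usage
encoder = create_dna_encoder(hidden_dim=256, num_layers=4)

# Batch of sequences
batch_size = 8
seq_len = 100
sequences = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (batch_size, seq_len), 0, 4),
    num_classes=4,
)

result, _, _ = encoder.apply({"sequence": sequences}, {}, None)
embeddings = result["embeddings"]  # (8, 256)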

Gradient Computation¤

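import jax
from flax import nnx

from diffbio.operators.foundation_models import create_dna_encoder

encoder = create_dna_encoder()

# One-hot input, built as in Basic Usage
sequence = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (100,), 0, 4),
    num_classes=4,
)

def loss_fn(model, sequence):
    result, _, _ = model.apply({"sequence": sequence}, {}, None)
    return result["embeddings"].sum()

# Compute gradients w.r.t. model parameters
_, grads = nnx.value_and_grad(loss_fn)(encoder, sequence)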

Input Specifications¤

| Key | Shape | Type | Description |
|-----|-------|------|-------------|
| sequence | (length, 4) or (batch, length, 4) | float32 | One-hot encoded nucleotide sequence |
| attention_mask | (length,) or (batch, length) | float32 | Optional mask (1=valid, 0=padded) |
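A batch of variable-length sequences needs padding plus a matching mask. The sketch below shows one way to build both, assuming right-padding with all-zero rows; `pad_and_mask` is a hypothetical helper, not a diffbio function.

```python
import jax
import jax.numpy as jnp


def pad_and_mask(one_hot_seqs, max_len):
    """Right-pad variable-length one-hot sequences and build the matching masks.

    one_hot_seqs: list of (length_i, 4) arrays with length_i <= max_len.
    Returns (batch, max_len, 4) sequences and (batch, max_len) float32 masks.
    """
    padded, masks = [], []
    for seq in one_hot_seqs:
        pad = max_len - seq.shape[0]
        padded.append(jnp.pad(seq, ((0, pad), (0, 0))))   # zero rows for padding
        masks.append(jnp.arange(max_len) < seq.shape[0])  # True where valid
    return jnp.stack(padded), jnp.stack(masks).astype(jnp.float32)


seqs = [jax.nn.one_hot(jnp.array(ids), 4) for ids in ([0, 1, 2], [3, 2, 1, 0, 0])]
sequences, attention_mask = pad_and_mask(seqs, max_len=5)
# These would be passed as {"sequence": sequences, "attention_mask": attention_mask}.
```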

Output Specifications¤

| Key | Shape | Type | Description |
|-----|-------|------|-------------|
| sequence | same as input | float32 | Original input sequence |
| embeddings | (hidden_dim,) or (batch, hidden_dim) | float32 | Global sequence embedding |
| token_embeddings | (length, hidden_dim) or (batch, length, hidden_dim) | float32 | Per-position hidden states |
| foundation_model | metadata dict | dict[str, uint8 array] | JIT-safe artifact metadata for model family, artifact ID, preprocessing version, adapter mode, and pooling strategy |

Reference Configurations¤

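| Model | hidden_dim | num_layers | num_heads | intermediate_dim |
|-------|------------|------------|-----------|------------------|
| DNABERT | 768 | 12 | 12 | 3072 |
| RNA-FM | 640 | 12 | 20 | 5120 |
| Default | 256 | 4 | 4 | 1024 |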
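A quick arithmetic check on these configurations: standard multi-head attention splits hidden_dim evenly across heads, so hidden_dim should be divisible by num_heads. The sketch below derives the per-head dimension and the feed-forward expansion ratio from the table values; `REFERENCE_CONFIGS` is a hypothetical name used only here.

```python
# (hidden_dim, num_heads, intermediate_dim) copied from the table above.
REFERENCE_CONFIGS = {
    "DNABERT": (768, 12, 3072),
    "RNA-FM": (640, 20, 5120),
    "Default": (256, 4, 1024),
}

head_dims = {}
ffn_ratios = {}
for name, (hidden_dim, num_heads, intermediate_dim) in REFERENCE_CONFIGS.items():
    # Each attention head works on an equal slice of the hidden state.
    if hidden_dim % num_heads != 0:
        raise ValueError(f"{name}: hidden_dim must be divisible by num_heads")
    head_dims[name] = hidden_dim // num_heads
    ffn_ratios[name] = intermediate_dim // hidden_dim
```

The same divisibility check is worth running on any custom TransformerSequenceEncoderConfig before building the encoder.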