Foundation Model Operators API¤
Differentiable transformer-based operators for DNA/RNA sequence embedding and single-cell foundation-model workflows.
TransformerSequenceEncoder¤
diffbio.operators.foundation_models.transformer_encoder.TransformerSequenceEncoder
¤
TransformerSequenceEncoder(
config: TransformerSequenceEncoderConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: FoundationEmbeddingMixin, SequenceOperator
Transformer-based encoder for DNA/RNA sequences.
This operator implements a BERT-style transformer encoder that converts nucleotide sequences into dense embeddings. The architecture follows DNABERT and RNA-FM patterns.
Uses artifex's TransformerEncoder for the core transformer layers, following the DRY principle.
Supports two input embedding modes:
- "linear" (default): Projects one-hot encoded input (seq_len, alphabet_size) via nnx.Linear. This is the standard mode for continuous one-hot input.
- "token_embedding": Embeds integer token IDs (seq_len,) via nnx.Embed. Useful for gene-token foundation models and tokenized input.
The encoder produces:
- Global sequence embedding via mean pooling or CLS token
- Per-position embeddings for fine-grained analysis
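
The two pooling strategies can be illustrated with a minimal JAX sketch. The helper names `mean_pool` and `cls_pool` are hypothetical and not part of diffbio, and the sketch assumes the CLS token sits at position 0; the operator's internal pooling may differ in detail.

```python
import jax.numpy as jnp


def mean_pool(token_embeddings, attention_mask):
    """Average per-position embeddings over valid (mask == 1) positions."""
    mask = attention_mask[..., None]                 # (..., seq_len, 1)
    total = (token_embeddings * mask).sum(axis=-2)   # (..., hidden_dim)
    count = jnp.maximum(mask.sum(axis=-2), 1.0)      # avoid division by zero
    return total / count


def cls_pool(token_embeddings):
    """Take the embedding at position 0 (assumed CLS position)."""
    return token_embeddings[..., 0, :]


token_embeddings = jnp.arange(12.0).reshape(4, 3)    # (seq_len=4, hidden_dim=3)
attention_mask = jnp.array([1.0, 1.0, 0.0, 0.0])     # last two positions padded
pooled_mean = mean_pool(token_embeddings, attention_mask)
pooled_cls = cls_pool(token_embeddings)
```

Both helpers index from the trailing axes, so they also accept batched input of shape (batch, seq_len, hidden_dim).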
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | TransformerSequenceEncoderConfig | TransformerSequenceEncoderConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply transformer encoding to sequence data.
This method encodes DNA/RNA sequences into dense embeddings using a transformer encoder architecture.
Input shape depends on input_embedding_type:
- "linear": one-hot (seq_len, alphabet_size) or (batch, seq_len, alphabet_size)
- "token_embedding": integer token IDs (seq_len,) or (batch, seq_len)
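
As an illustration of the two input layouts, the following sketch builds matching inputs for both modes using only JAX. The variable names are illustrative, and the dictionaries simply follow the "sequence" key described for `apply`; nothing here calls into diffbio.

```python
import jax
import jax.numpy as jnp

seq_len, alphabet_size = 12, 4
key = jax.random.PRNGKey(0)

# Integer token IDs, shape (seq_len,): the layout for "token_embedding" mode.
token_ids = jax.random.randint(key, (seq_len,), 0, alphabet_size)

# One-hot encoding, shape (seq_len, alphabet_size): the layout for "linear" mode.
one_hot = jax.nn.one_hot(token_ids, num_classes=alphabet_size)

# Batched variants add a leading batch axis.
batched_ids = jax.random.randint(key, (3, seq_len), 0, alphabet_size)
batched_one_hot = jax.nn.one_hot(batched_ids, num_classes=alphabet_size)

linear_input = {"sequence": one_hot}      # for input_embedding_type="linear"
token_input = {"sequence": token_ids}     # for input_embedding_type="token_embedding"
```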
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing "sequence", the encoded sequence(s) (see above for shapes), and optionally "attention_mask", a mask of shape (seq_len,) or (batch, seq_len). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata). The keys of transformed_data are listed under Output Specifications. |
get_positional_encoding
¤
Generate sinusoidal positional encoding.
This is provided for compatibility but the transformer uses internal positional encoding.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seq_len | int | Sequence length. | required |

Returns:
| Type | Description |
|---|---|
| Float[Array, 'seq_len hidden_dim'] | Positional encoding matrix. |
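
For reference, the standard sinusoidal encoding can be sketched as below. This is the textbook formulation, assumed rather than confirmed to match what `get_positional_encoding` returns; the function name and the requirement of an even `hidden_dim` are assumptions of this sketch.

```python
import jax.numpy as jnp


def sinusoidal_positional_encoding(seq_len, hidden_dim):
    """Return a (seq_len, hidden_dim) matrix; hidden_dim is assumed even."""
    positions = jnp.arange(seq_len)[:, None]           # (seq_len, 1)
    dims = jnp.arange(0, hidden_dim, 2)[None, :]       # (1, hidden_dim / 2)
    angles = positions / jnp.power(10000.0, dims / hidden_dim)
    encoding = jnp.zeros((seq_len, hidden_dim))
    encoding = encoding.at[:, 0::2].set(jnp.sin(angles))  # even columns: sine
    encoding = encoding.at[:, 1::2].set(jnp.cos(angles))  # odd columns: cosine
    return encoding


pe = sinusoidal_positional_encoding(seq_len=10, hidden_dim=8)
```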
TransformerSequenceEncoderConfig¤
diffbio.operators.foundation_models.transformer_encoder.TransformerSequenceEncoderConfig
dataclass
¤
TransformerSequenceEncoderConfig(
adapter_mode: AdapterMode = NATIVE_TRAINABLE,
artifact_id: str = "diffbio.transformer_sequence_encoder",
preprocessing_version: str = "one_hot_v1",
dropout_rate: float = 0.1,
pooling: Literal["mean", "cls"] = "mean",
alphabet_size: int = 4,
input_embedding_type: Literal[
"linear", "token_embedding"
] = "linear",
vocab_size: int | None = None,
hidden_dim: int = 256,
num_layers: int = 4,
num_heads: int = 4,
intermediate_dim: int = 1024,
max_length: int = 512,
)
Bases: _TransformerArchitectureConfig, _TransformerInputConfig, _TransformerOutputConfig, TransformerEncoderShapeValidationMixin, FoundationEmbeddingOperatorConfig
Configuration for TransformerSequenceEncoder.
Factory Functions¤
create_dna_encoder¤
diffbio.operators.foundation_models.transformer_encoder.create_dna_encoder
¤
create_dna_encoder(
hidden_dim: int = 256,
num_layers: int = 4,
num_heads: int = 4,
intermediate_dim: int | None = None,
max_length: int = 512,
dropout_rate: float = 0.1,
pooling: Literal["mean", "cls"] = "mean",
*,
rngs: Rngs | None = None,
) -> TransformerSequenceEncoder
Create a transformer encoder for DNA sequences.
Factory function for creating a DNA sequence encoder with sensible defaults for DNA processing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Dimension of hidden states. | 256 |
| num_layers | int | Number of transformer layers. | 4 |
| num_heads | int | Number of attention heads. | 4 |
| intermediate_dim | int \| None | FFN intermediate dimension (default: 4 * hidden_dim). | None |
| max_length | int | Maximum sequence length. | 512 |
| dropout_rate | float | Dropout rate. | 0.1 |
| pooling | Literal['mean', 'cls'] | Pooling strategy. | 'mean' |
| rngs | Rngs \| None | Random number generators. | None |

Returns:
| Type | Description |
|---|---|
| TransformerSequenceEncoder | Configured TransformerSequenceEncoder for DNA. |
create_rna_encoder¤
diffbio.operators.foundation_models.transformer_encoder.create_rna_encoder
¤
create_rna_encoder(
hidden_dim: int = 256,
num_layers: int = 4,
num_heads: int = 4,
intermediate_dim: int | None = None,
max_length: int = 512,
dropout_rate: float = 0.1,
pooling: Literal["mean", "cls"] = "mean",
*,
rngs: Rngs | None = None,
) -> TransformerSequenceEncoder
Create a transformer encoder for RNA sequences.
Factory function for creating an RNA sequence encoder with sensible defaults for RNA processing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Dimension of hidden states. | 256 |
| num_layers | int | Number of transformer layers. | 4 |
| num_heads | int | Number of attention heads. | 4 |
| intermediate_dim | int \| None | FFN intermediate dimension (default: 4 * hidden_dim). | None |
| max_length | int | Maximum sequence length. | 512 |
| dropout_rate | float | Dropout rate. | 0.1 |
| pooling | Literal['mean', 'cls'] | Pooling strategy. | 'mean' |
| rngs | Rngs \| None | Random number generators. | None |

Returns:
| Type | Description |
|---|---|
| TransformerSequenceEncoder | Configured TransformerSequenceEncoder for RNA. |
DifferentiableFoundationModel¤
diffbio.operators.foundation_models.foundation_model.DifferentiableFoundationModel
¤
DifferentiableFoundationModel(
config: FoundationModelConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: FoundationEmbeddingMixin, MaskedGeneTransformerOperatorMixin, OperatorModule
Differentiable single-cell foundation model operator.
Implements a masked gene expression prediction model inspired by Geneformer (rank-value tokenization) and scGPT (masked expression prediction with gene + value embeddings).
Algorithm:
- Tokenize: Rank genes by expression per cell via soft sort (Geneformer-style, used for gene embedding ordering context).
- Embed gene IDs via TransformerSequenceEncoder with input_embedding_type="token_embedding".
- Add expression value projection: scalar expression values are projected to hidden_dim and added to gene embeddings (scGPT-style).
- Random mask: a mask_ratio fraction of genes have their expression embeddings replaced with a learned mask token.
- Transformer encoder: contextualizes gene representations.
- Predict: linear output head predicts masked gene expression.
- Cell embedding: mean pooling of non-masked gene representations.
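
The masking and cell-embedding steps above can be sketched in plain JAX. The function names are hypothetical and the arrays stand in for learned embeddings; this sketch draws an independent Bernoulli mask per gene, so the masked fraction is only approximately `mask_ratio`, and the operator's own sampling scheme may differ.

```python
import jax
import jax.numpy as jnp


def mask_expression_embeddings(key, gene_emb, value_emb, mask_token, mask_ratio):
    """Replace value embeddings of randomly chosen genes with a mask token."""
    n_genes = gene_emb.shape[0]
    is_masked = jax.random.uniform(key, (n_genes,)) < mask_ratio
    value_emb = jnp.where(is_masked[:, None], mask_token[None, :], value_emb)
    # Gene-identity and expression-value embeddings are summed (scGPT-style).
    return gene_emb + value_emb, is_masked


def cell_embedding(hidden, is_masked):
    """Mean-pool gene representations over non-masked genes only."""
    keep = (~is_masked).astype(hidden.dtype)[:, None]
    return (hidden * keep).sum(axis=0) / jnp.maximum(keep.sum(), 1.0)


n_genes, hidden_dim = 1000, 8
gene_emb = jnp.zeros((n_genes, hidden_dim))    # stand-in for gene ID embeddings
value_emb = jnp.ones((n_genes, hidden_dim))    # stand-in for projected expression
mask_token = jnp.full((hidden_dim,), -1.0)     # stand-in for the learned mask token

tokens, is_masked = mask_expression_embeddings(
    jax.random.PRNGKey(0), gene_emb, value_emb, mask_token, mask_ratio=0.15
)
cell = cell_embedding(tokens, is_masked)       # (hidden_dim,)
```

In the full model the pooled representations would come from the transformer encoder output rather than directly from the masked tokens; the sketch skips that step to keep the masking logic visible.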
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | FoundationModelConfig | FoundationModelConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
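Example
```python
import jax
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.foundation_models.foundation_model import (
    DifferentiableFoundationModel, FoundationModelConfig)

config = FoundationModelConfig(n_genes=2000, hidden_dim=128)
model = DifferentiableFoundationModel(
    config, rngs=nnx.Rngs(params=0, sample=1, dropout=2))
rp = model.generate_random_params(
    jax.random.key(0), {"counts": (100, 2000)})
counts = jnp.ones((100, 2000))  # placeholder counts matching the shape hint above
data = {"counts": counts, "gene_ids": jnp.arange(2000)}
result, state, meta = model.apply(data, {}, None, random_params=rp)
```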
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply foundation model to single-cell count data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing "counts" and "gene_ids", as in the class example. | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | JAX random key for mask generation. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata). |
FoundationModelConfig¤
diffbio.operators.foundation_models.foundation_model.FoundationModelConfig
dataclass
¤
FoundationModelConfig(
n_genes: int = 2000,
hidden_dim: int = 128,
num_layers: int = 2,
num_heads: int = 4,
mask_ratio: float = 0.15,
dropout_rate: float = 0.1,
adapter_mode: AdapterMode = NATIVE_TRAINABLE,
artifact_id: str = "diffbio.differentiable_foundation_model",
preprocessing_version: str = "counts_gene_ids_v1",
)
Bases: FoundationEmbeddingOperatorConfig, MaskedGeneTransformerConfigBase
Configuration for DifferentiableFoundationModel.
Attributes:
| Name | Type | Description |
|---|---|---|
| adapter_mode | AdapterMode | Integration mode for the model artifact. |
| artifact_id | str | Identifier for the model artifact/version. |
| preprocessing_version | str | Version tag for count/gene-ID preprocessing. |
GeneTokenizer¤
diffbio.operators.foundation_models.foundation_model.GeneTokenizer
¤
Bases: Module
Geneformer-style rank-value gene tokenizer.
Converts gene expression vectors into rank-ordered representations using
a differentiable soft sort approximation. For each cell, genes are ranked
by expression value in descending order. The output is a soft permutation
matrix of shape (n_genes, n_genes) where row i is a soft one-hot
indicating which gene occupies rank i.
The key insight from Geneformer: token IDs are gene indices sorted by expression magnitude. We approximate the discrete argsort with a temperature-controlled soft permutation to maintain differentiability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_genes | int | Number of genes in the vocabulary. | required |
| rngs | Rngs | Flax NNX random number generators (unused, kept for NNX API). | required |
__call__
¤
__call__(
expression: Float[Array, n_genes],
temperature: float = 1.0,
) -> Float[Array, "n_genes n_genes"]
Compute soft permutation matrix from expression values.
Genes are ranked in descending order of expression. The returned
matrix P has shape (n_genes, n_genes) where P[i, j]
approximates the probability that gene j occupies rank i.
At low temperature this approaches a hard permutation matrix (the true argsort).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| expression | Float[Array, n_genes] | Gene expression values for one cell, shape (n_genes,). | required |
| temperature | float | Softmax temperature (lower is sharper). | 1.0 |

Returns:
| Type | Description |
|---|---|
| Float[Array, 'n_genes n_genes'] | Soft permutation matrix of shape (n_genes, n_genes). |
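
One common way to build such a soft permutation is a SoftSort-style relaxation: compare each sorted value against every input value and apply a row-wise softmax. The sketch below is an assumption about the general technique, not a copy of the GeneTokenizer internals, and `soft_rank_permutation` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def soft_rank_permutation(expression, temperature=1.0):
    """Row i is a soft one-hot over genes for rank i (descending order)."""
    sorted_desc = jnp.sort(expression)[::-1]
    # distances[i, j] = |value at rank i - expression of gene j|
    distances = jnp.abs(sorted_desc[:, None] - expression[None, :])
    return jax.nn.softmax(-distances / temperature, axis=-1)


expression = jnp.array([0.5, 3.0, 1.0, 2.0])
P = soft_rank_permutation(expression, temperature=0.01)
hard_ranks = jnp.argmax(P, axis=-1)   # gene index occupying each rank
```

At this low temperature the rows are nearly one-hot, so `hard_ranks` recovers the descending argsort; raising the temperature spreads probability mass across genes with similar expression.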
Usage Examples¤
Basic Usage¤
```python
from diffbio.operators.foundation_models import create_dna_encoder
import jax
import jax.numpy as jnp

# Create encoder
encoder = create_dna_encoder(hidden_dim=256, num_layers=4)

# Prepare one-hot encoded sequence
sequence = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (100,), 0, 4),
    num_classes=4,
)

# Apply
result, _, _ = encoder.apply({"sequence": sequence}, {}, None)
embedding = result["embeddings"]  # (256,)
```
Full Configuration¤
```python
from diffbio.operators.foundation_models import (
    TransformerSequenceEncoder,
    TransformerSequenceEncoderConfig,
)
from flax import nnx

config = TransformerSequenceEncoderConfig(
    hidden_dim=640,
    num_layers=12,
    num_heads=20,
    intermediate_dim=5120,
    max_length=1024,
    alphabet_size=4,
    dropout_rate=0.1,
    pooling="cls",
)
encoder = TransformerSequenceEncoder(config, rngs=nnx.Rngs(42))
```
Batched Processing¤
```python
import jax
import jax.numpy as jnp

# Reuses `encoder` from Basic Usage (hidden_dim=256)
# Batch of sequences
batch_size = 8
seq_len = 100
sequences = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (batch_size, seq_len), 0, 4),
    num_classes=4,
)
result, _, _ = encoder.apply({"sequence": sequences}, {}, None)
embeddings = result["embeddings"]  # (8, 256)
```
Gradient Computation¤
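```python
import jax
from flax import nnx
from diffbio.operators.foundation_models import create_dna_encoder

encoder = create_dna_encoder()

# One-hot encoded input, as in Basic Usage
sequence = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (100,), 0, 4),
    num_classes=4,
)

def loss_fn(model, sequence):
    result, _, _ = model.apply({"sequence": sequence}, {}, None)
    return result["embeddings"].sum()

# Compute gradients w.r.t. model parameters
_, grads = nnx.value_and_grad(loss_fn)(encoder, sequence)
```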
Input Specifications¤
| Key | Shape | Type | Description |
|---|---|---|---|
| sequence | (length, 4) or (batch, length, 4) | float32 | One-hot encoded nucleotide sequence |
| attention_mask | (length,) or (batch, length) | float32 | Optional mask (1=valid, 0=padded) |
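
Variable-length sequences need padding to a common length before batching. The sketch below shows one way to build the `sequence` and `attention_mask` arrays in the shapes listed above; `pad_one_hot` is a hypothetical helper, and it assumes every sequence is no longer than `max_len`.

```python
import jax
import jax.numpy as jnp


def pad_one_hot(sequences, max_len, alphabet_size=4):
    """One-hot encode integer sequences and zero-pad them to max_len."""
    batch, masks = [], []
    for ids in sequences:
        ids = jnp.asarray(ids)
        n = ids.shape[0]
        one_hot = jax.nn.one_hot(ids, num_classes=alphabet_size)
        # Pad along the length axis only; padded rows are all zeros.
        batch.append(jnp.pad(one_hot, ((0, max_len - n), (0, 0))))
        # 1 = valid position, 0 = padded position.
        masks.append(jnp.pad(jnp.ones(n), (0, max_len - n)))
    return jnp.stack(batch), jnp.stack(masks)


sequences, attention_mask = pad_one_hot([[0, 1, 2], [3, 3, 2, 1, 0]], max_len=5)
data = {"sequence": sequences, "attention_mask": attention_mask}
```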
Output Specifications¤
| Key | Shape | Type | Description |
|---|---|---|---|
| sequence | same as input | float32 | Original input sequence |
| embeddings | (hidden_dim,) or (batch, hidden_dim) | float32 | Global sequence embedding |
| token_embeddings | (length, hidden_dim) or (batch, length, hidden_dim) | float32 | Per-position hidden states |
| foundation_model | metadata dict | dict[str, uint8 array] | JIT-safe artifact metadata for model family, artifact ID, preprocessing version, adapter mode, and pooling strategy |
Reference Configurations¤
| Model | hidden_dim | num_layers | num_heads | intermediate_dim |
|---|---|---|---|---|
| DNABERT | 768 | 12 | 12 | 3072 |
| RNA-FM | 640 | 12 | 20 | 5120 |
| Default | 256 | 4 | 4 | 1024 |
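
As a quick sanity check on the table, the per-head dimension and FFN expansion ratio can be derived from each row. The dictionary below just restates the table values, and `head_dim` is a hypothetical helper; it assumes the usual multi-head attention constraint that `hidden_dim` is divisible by `num_heads`.

```python
REFERENCE_CONFIGS = {
    "DNABERT": dict(hidden_dim=768, num_layers=12, num_heads=12, intermediate_dim=3072),
    "RNA-FM": dict(hidden_dim=640, num_layers=12, num_heads=20, intermediate_dim=5120),
    "Default": dict(hidden_dim=256, num_layers=4, num_heads=4, intermediate_dim=1024),
}


def head_dim(cfg):
    """Per-head dimension; requires hidden_dim to be divisible by num_heads."""
    if cfg["hidden_dim"] % cfg["num_heads"] != 0:
        raise ValueError("hidden_dim must be divisible by num_heads")
    return cfg["hidden_dim"] // cfg["num_heads"]


head_dims = {name: head_dim(cfg) for name, cfg in REFERENCE_CONFIGS.items()}
ffn_ratios = {
    name: cfg["intermediate_dim"] // cfg["hidden_dim"]
    for name, cfg in REFERENCE_CONFIGS.items()
}
```

The DNABERT and Default rows use the 4x FFN expansion that the factory functions fall back to when `intermediate_dim` is None, while the RNA-FM row uses 8x and so needs `intermediate_dim` set explicitly.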