Core Base Classes¤
DiffBio operators inherit from Datarax's base classes, providing a consistent interface for composable, differentiable data processing.
OperatorModule¤
All DiffBio operators inherit from datarax.core.operator.OperatorModule, which provides:
- Consistent apply() interface for data transformation
- apply_batch() for batch processing
- Integration with Flax NNX for learnable parameters
Interface¤
class OperatorModule:
    def apply(
        self,
        data: PyTree,
        state: PyTree,
        metadata: dict | None,
        random_params: Any = None,
        stats: dict | None = None,
    ) -> tuple[PyTree, PyTree, dict | None]:
        """Transform data through the operator.

        Args:
            data: Input data (typically dict of arrays)
            state: Per-element state
            metadata: Optional element metadata
            random_params: Random parameters for stochastic ops
            stats: Optional statistics dictionary

        Returns:
            Tuple of (transformed_data, updated_state, updated_metadata)
        """
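The calling convention above can be illustrated without Datarax installed. The sketch below is a stand-alone toy class, not a subclass of the real OperatorModule; the name ScaleOperator and the "values" key are hypothetical. It only shows that apply() receives (data, state, metadata) and returns the same three-part tuple.

```python
from __future__ import annotations

from typing import Any


class ScaleOperator:
    """Toy stand-in that mimics the apply() calling convention (hypothetical)."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def apply(
        self,
        data: dict[str, list[float]],
        state: dict[str, Any],
        metadata: dict | None,
        random_params: Any = None,
        stats: dict | None = None,
    ) -> tuple[dict[str, list[float]], dict[str, Any], dict | None]:
        # Build a new dictionary so the input data is not mutated.
        scaled = {key: [self.factor * v for v in values] for key, values in data.items()}
        # State and metadata are passed through unchanged in this sketch.
        return scaled, state, metadata


op = ScaleOperator(factor=2.0)
out, state, metadata = op.apply({"values": [1.0, 2.5]}, {}, None)
print(out)  # {'values': [2.0, 5.0]}
```

Returning state and metadata even when they are unchanged keeps every operator's output shape identical, which is what lets operators be chained.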
Usage Pattern¤
from diffbio.operators.alignment import SmoothSmithWaterman, SmithWatermanConfig
from diffbio.operators.alignment import create_dna_scoring_matrix
# Create operator
config = SmithWatermanConfig(temperature=1.0)
scoring = create_dna_scoring_matrix(match=2.0, mismatch=-1.0)
operator = SmoothSmithWaterman(config, scoring_matrix=scoring)
# Prepare data
data = {"seq1": seq1_tensor, "seq2": seq2_tensor}
state = {}
metadata = None
# Apply operator
result_data, state, metadata = operator.apply(data, state, metadata)
OperatorConfig¤
Configuration base class for operators from datarax.core.config.OperatorConfig:
from dataclasses import dataclass
from datarax.core.config import OperatorConfig


@dataclass(frozen=True)
class MyOperatorConfig(OperatorConfig):
    """Configuration for MyOperator.

    Attributes:
        my_param: Description of parameter.
    """

    my_param: float = 1.0
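Because the config is declared with frozen=True, instances cannot be modified after construction. The sketch below uses a plain dataclass named ToyConfig (hypothetical, and not derived from the real OperatorConfig, which may add behaviour of its own) to show the usual way of deriving a modified copy with dataclasses.replace.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class ToyConfig:
    """Stand-in for an OperatorConfig subclass (hypothetical)."""

    my_param: float = 1.0


base = ToyConfig()
try:
    base.my_param = 2.0  # frozen dataclasses reject attribute assignment
    mutated = True
except FrozenInstanceError:
    mutated = False

# Derive a modified copy instead of mutating in place.
tuned = replace(base, my_param=2.0)
print(mutated, base.my_param, tuned.my_param)  # False 1.0 2.0
```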
DiffBio Operator Hierarchy¤
graph TB
A["datarax.core.operator.OperatorModule"] --> T["TemperatureOperator"]
T --> C["SmoothSmithWaterman"]
T --> D["DifferentiablePileup"]
T --> Q["DifferentiableQualityFilter"]
A --> E["VariantClassifier"]
A --> F["VariantCallingPipeline"]
S["diffbio.core.soft_ops"] -.->|"used by"| T
style A fill:#ede9fe,stroke:#7c3aed,color:#4c1d95
style T fill:#fef3c7,stroke:#d97706,color:#92400e
style S fill:#d1fae5,stroke:#059669,color:#065f46
style C fill:#e0e7ff,stroke:#4338ca,color:#312e81
style D fill:#e0e7ff,stroke:#4338ca,color:#312e81
style Q fill:#e0e7ff,stroke:#4338ca,color:#312e81
style E fill:#e0e7ff,stroke:#4338ca,color:#312e81
style F fill:#e0e7ff,stroke:#4338ca,color:#312e81
The soft_ops module provides the differentiable primitives (sorting, comparisons, selection) used by TemperatureOperator and its subclasses. See the Soft Operations API reference for the full list.
Learnable Parameters¤
DiffBio uses Flax NNX's nnx.Param for learnable parameters:
from flax import nnx
import jax.numpy as jnp
from datarax.core.operator import OperatorModule


class MyOperator(OperatorModule):
    def __init__(self, config, rngs):
        super().__init__(config, rngs=rngs)
        # Learnable parameters
        self.temperature = nnx.Param(jnp.array(1.0))
        self.threshold = nnx.Param(jnp.array(20.0))
Accessing Parameters¤
# Get all parameters
params = nnx.state(operator, nnx.Param)
# Access specific parameter value
value = operator.temperature[...]
# Update parameter
operator.temperature[...] = new_value
Gradient Computation¤
import jax


def loss_fn(operator, data):
    result, _, _ = operator.apply(data, {}, None)
    return result["score"]


# Compute gradients w.r.t. parameters
grads = jax.grad(loss_fn)(operator, data)
Graph Utilities¤
compute_pairwise_distances¤
diffbio.core.graph_utils.compute_pairwise_distances
¤
Compute pairwise distance matrix between all samples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| features | Array | Input feature matrix of shape | required |
| metric | str | Distance metric, either | 'euclidean' |
Returns:
| Type | Description |
|---|---|
| Array | Distance matrix of shape |
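As an illustration of the default 'euclidean' metric, the following is a minimal NumPy sketch, not DiffBio's implementation. The function name pairwise_euclidean and the (n_samples, n_features) input layout are assumptions made for the example.

```python
import numpy as np


def pairwise_euclidean(features: np.ndarray) -> np.ndarray:
    """Return the (n, n) Euclidean distance matrix for n row vectors."""
    sq_norms = np.sum(features**2, axis=1)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, clipped at 0 against round-off.
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * features @ features.T
    return np.sqrt(np.maximum(sq_dists, 0.0))


points = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
dists = pairwise_euclidean(points)
```

A differentiable version usually needs extra care at zero distance, because the derivative of the square root is unbounded there; adding a small constant inside the square root is one common workaround.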
compute_knn_graph¤
diffbio.core.graph_utils.compute_knn_graph
¤
Build a k-nearest-neighbor graph from a dense distance matrix.
For each node the k closest neighbours (by distance) are selected. Self-connections are assumed to already be masked out by setting the diagonal to a large value before calling this function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| distances | Array | Dense distance matrix of shape | required |
| k | int | Number of nearest neighbours per node. Clipped to | required |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | A tuple |
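The selection step described above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the library code: the function name knn_from_distances, the clipping of k to n - 1, and the choice to return (indices, distances) are all assumptions, since the documented return value is only described as a tuple.

```python
from __future__ import annotations

import numpy as np


def knn_from_distances(distances: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Return (neighbour indices, neighbour distances), each of shape (n, k)."""
    n = distances.shape[0]
    k = min(k, n - 1)  # assumed clipping rule: at most n - 1 neighbours
    # Sort each row by distance and keep the k smallest entries.
    indices = np.argsort(distances, axis=1)[:, :k]
    neighbour_dists = np.take_along_axis(distances, indices, axis=1)
    return indices, neighbour_dists


dists = np.array(
    [
        [0.0, 1.0, 4.0],
        [1.0, 0.0, 2.0],
        [4.0, 2.0, 0.0],
    ]
)
# Mask self-connections by putting a large value on the diagonal first.
masked = dists + np.diag(np.full(3, 1e9))
idx, nd = knn_from_distances(masked, k=1)
print(idx.ravel())  # [1 0 1]
```

Without the diagonal mask, every node would select itself as its nearest neighbour, because its self-distance is zero.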
compute_fuzzy_membership¤
diffbio.core.graph_utils.compute_fuzzy_membership
¤
Compute fuzzy set membership using a Gaussian kernel with local bandwidth.
The bandwidth (sigma) for each sample is set to the distance to its k-th nearest neighbour, making the kernel adapt to local density. The diagonal of the output is forced to zero (no self-similarity).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| distances | Array | Dense distance matrix of shape | required |
| k | int | Number of neighbours used to determine the local bandwidth. Clipped to | required |
Returns:
| Type | Description |
|---|---|
| Array | Fuzzy membership matrix of shape |
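A rough NumPy sketch of a Gaussian kernel with per-sample bandwidth is given below. The exact kernel form exp(-d^2 / (2 * sigma_i^2)), the clipping of k to n - 1, and the function name fuzzy_membership are assumptions for illustration; the library may normalise differently. The input is assumed to have a zero diagonal.

```python
import numpy as np


def fuzzy_membership(distances: np.ndarray, k: int) -> np.ndarray:
    """Gaussian membership with sigma_i = distance to the k-th neighbour of i."""
    n = distances.shape[0]
    k = min(k, n - 1)  # assumed clipping rule
    # Column 0 of each sorted row is the zero self-distance, so column k
    # holds the distance to the k-th nearest neighbour.
    sigma = np.sort(distances, axis=1)[:, k]
    sigma = np.maximum(sigma, 1e-12)  # guard against division by zero
    # Assumed kernel form, with one bandwidth per row.
    membership = np.exp(-(distances**2) / (2.0 * sigma[:, None] ** 2))
    np.fill_diagonal(membership, 0.0)  # no self-similarity
    return membership


dists = np.array(
    [
        [0.0, 1.0, 4.0],
        [1.0, 0.0, 2.0],
        [4.0, 2.0, 0.0],
    ]
)
m = fuzzy_membership(dists, k=1)
```

Because each row uses its own bandwidth, the result is generally not symmetric: here m[1, 2] and m[2, 1] differ even though the underlying distance is the same.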
symmetrize_graph¤
diffbio.core.graph_utils.symmetrize_graph
¤
Symmetrize a directed adjacency matrix via fuzzy set union.
Applies the probabilistic (fuzzy) union:
p_sym = p + p^T - p * p^T
This ensures the output is symmetric and, when inputs are in [0, 1], the outputs remain in [0, 1].
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| adjacency | Array | Directed adjacency / membership matrix of shape | required |
Returns:
| Type | Description |
|---|---|
| Array | Symmetric adjacency matrix of shape |
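The fuzzy union formula is short enough to check by hand. The NumPy sketch below applies it elementwise to a small directed matrix; the function name fuzzy_union_symmetrize is hypothetical.

```python
import numpy as np


def fuzzy_union_symmetrize(p: np.ndarray) -> np.ndarray:
    """Apply p + p^T - p * p^T elementwise."""
    return p + p.T - p * p.T


directed = np.array([[0.0, 0.8], [0.5, 0.0]])
sym = fuzzy_union_symmetrize(directed)
# Both off-diagonal entries become 0.8 + 0.5 - 0.8 * 0.5 = 0.9.
```

The result is never smaller than either directed weight, so an edge that is strong in one direction stays strong after symmetrization.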
GNN Components¤
GraphAttentionLayer¤
diffbio.core.gnn_components.GraphAttentionLayer
¤
GraphAttentionLayer(
    in_features: int,
    out_features: int,
    num_heads: int,
    edge_features: int,
    dropout_rate: float,
    *,
    rngs: Rngs,
)
Bases: Module
Multi-head graph attention layer for message passing.
Computes attention-weighted message aggregation over graph edges. Each attention head independently computes query/key/value projections, adds an edge-feature bias to attention scores, normalizes via segment-softmax, and aggregates weighted values per target node.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input node feature dimension. | required |
| out_features | int | Output feature dimension (must be divisible by num_heads). | required |
| num_heads | int | Number of parallel attention heads. | required |
| edge_features | int | Edge feature dimension. | required |
| dropout_rate | float | Dropout rate applied to attention weights. | required |
| rngs | Rngs | Flax NNX random number generators. | required |
__call__
¤
__call__(
    node_features: Float[Array, "n_nodes in_features"],
    edge_index: Int[Array, "2 n_edges"],
    edge_features: Float[Array, "n_edges edge_features"],
    *,
    deterministic: bool = True,
) -> Float[Array, "n_nodes out_features"]
Run one graph-attention update step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node_features | Float[Array, 'n_nodes in_features'] | Node feature matrix of shape (n_nodes, in_features). | required |
| edge_index | Int[Array, '2 n_edges'] | Edge indices of shape (2, n_edges). | required |
| edge_features | Float[Array, 'n_edges edge_features'] | Edge feature matrix of shape (n_edges, edge_features). | required |
| deterministic | bool | Whether to disable stochastic dropout. | True |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'n_nodes out_features'] | Updated node features of shape (n_nodes, out_features). |
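The segment-softmax and per-target aggregation mentioned in the layer description can be sketched for a single head in NumPy. This is an illustrative sketch, not the layer's code: the helper name segment_softmax is hypothetical, and it assumes that the second row of edge_index holds target nodes.

```python
import numpy as np


def segment_softmax(scores: np.ndarray, segment_ids: np.ndarray, num_segments: int) -> np.ndarray:
    """Softmax over entries that share the same segment id."""
    # Subtract the per-segment maximum for numerical stability.
    seg_max = np.full(num_segments, -np.inf)
    np.maximum.at(seg_max, segment_ids, scores)
    exp_scores = np.exp(scores - seg_max[segment_ids])
    # Normalise by the per-segment sum.
    seg_sum = np.zeros(num_segments)
    np.add.at(seg_sum, segment_ids, exp_scores)
    return exp_scores / seg_sum[segment_ids]


# Three edges: two point at node 0, one points at node 1
# (assumption: row 1 of edge_index holds the target nodes).
edge_index = np.array([[1, 2, 0], [0, 0, 1]])
scores = np.array([0.0, 0.0, 3.0])
values = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])  # one message per edge

weights = segment_softmax(scores, edge_index[1], num_segments=3)
aggregated = np.zeros((3, 2))
# Scatter-add the weighted messages into their target nodes.
np.add.at(aggregated, edge_index[1], weights[:, None] * values)
```

Node 2 has no incoming edges in this example, so its aggregated row stays at zero.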
GraphAttentionBlock¤
diffbio.core.gnn_components.GraphAttentionBlock
¤
GraphAttentionBlock(
    hidden_dim: int,
    num_heads: int,
    edge_features: int,
    dropout_rate: float,
    *,
    rngs: Rngs,
)
Bases: Module
Full GNN block: graph attention + LayerNorm + residual + feedforward.
Combines a GraphAttentionLayer with pre-norm residual connections and a two-layer feedforward network (4x expansion), following the standard Transformer block pattern adapted for graphs.
Architecture:
x -> GraphAttentionLayer -> + -> LayerNorm -> FFN -> + -> LayerNorm -> out
|___________________________|              |_________|
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension (both input and output). | required |
| num_heads | int | Number of attention heads. | required |
| edge_features | int | Edge feature dimension. | required |
| dropout_rate | float | Dropout rate. | required |
| rngs | Rngs | Flax NNX random number generators. | required |
__call__
¤
__call__(
    node_features: Float[Array, "n_nodes hidden_dim"],
    edge_index: Int[Array, "2 n_edges"],
    edge_features: Float[Array, "n_edges edge_features"],
    *,
    deterministic: bool = True,
) -> Float[Array, "n_nodes hidden_dim"]
Apply the GNN block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node_features | Float[Array, 'n_nodes hidden_dim'] | Node features of shape (n_nodes, hidden_dim). | required |
| edge_index | Int[Array, '2 n_edges'] | Edge indices of shape (2, n_edges). | required |
| edge_features | Float[Array, 'n_edges edge_features'] | Edge features of shape (n_edges, edge_features). | required |
| deterministic | bool | If True, disable dropout. | True |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'n_nodes hidden_dim'] | Updated node features of shape (n_nodes, hidden_dim). |
GATv2Layer¤
diffbio.core.gnn_components.GATv2Layer
¤
GATv2Layer(
    in_features: int,
    out_features: int,
    num_heads: int,
    edge_features: int,
    dropout_rate: float,
    negative_slope: float = 0.2,
    *,
    rngs: Rngs,
)
Bases: Module
GATv2 multi-head graph attention layer.
Unlike the original GAT (GraphAttentionLayer), GATv2 applies LeakyReLU before computing the attention scalar, which makes the attention function strictly more expressive (it can represent any monotonic scoring function over concatenated source/target features).
GATv2 attention:
e_{ij} = a^T * LeakyReLU(W_l * h_i + W_r * h_j + edge_bias)
This is the key difference from GAT, where the nonlinearity is applied after the attention dot product.
Reference: Brody, Alon, Yahav. "How Attentive are Graph Attention Networks?" (ICLR 2022).
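To make the ordering difference concrete, the sketch below computes both scores for one head and one node pair in NumPy. It follows the formula above for GATv2 and the standard GAT form LeakyReLU(a_l . W h_i + a_r . W h_j) for comparison; the function names and the toy weights are hypothetical, and this is not the layer's implementation.

```python
import numpy as np


def leaky_relu(x, negative_slope=0.2):
    return np.where(x >= 0, x, negative_slope * x)


def gatv2_score(h_i, h_j, W_l, W_r, a, edge_bias=0.0, negative_slope=0.2):
    """e_ij = a^T LeakyReLU(W_l h_i + W_r h_j + edge_bias) for one head."""
    return a @ leaky_relu(W_l @ h_i + W_r @ h_j + edge_bias, negative_slope)


def gat_score(h_i, h_j, W, a_l, a_r, negative_slope=0.2):
    """Original GAT: LeakyReLU applied after the linear attention scalar."""
    return leaky_relu(a_l @ (W @ h_i) + a_r @ (W @ h_j), negative_slope)


identity = np.eye(2)
ones = np.ones(2)
h_i = np.array([1.0, -1.0])
h_j = np.zeros(2)

v2 = gatv2_score(h_i, h_j, identity, identity, ones)  # 1.0 + (-0.2) = 0.8
v1 = gat_score(h_i, h_j, identity, ones, ones)        # LeakyReLU(1 - 1) = 0.0
```

With identical toy weights the two scores differ, because GATv2 applies the nonlinearity to each feature before the features are summed by a.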
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input node feature dimension. | required |
| out_features | int | Output feature dimension (must be divisible by num_heads). | required |
| num_heads | int | Number of parallel attention heads. | required |
| edge_features | int | Edge feature dimension. | required |
| dropout_rate | float | Dropout rate applied to attention weights. | required |
| negative_slope | float | Negative slope for LeakyReLU. | 0.2 |
| rngs | Rngs | Flax NNX random number generators. | required |
__call__
¤
__call__(
    node_features: Float[Array, "n_nodes in_features"],
    edge_index: Int[Array, "2 n_edges"],
    edge_features: Float[Array, "n_edges edge_features"],
    *,
    deterministic: bool = True,
) -> Float[Array, "n_nodes out_features"]
Run one GATv2 attention update step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node_features | Float[Array, 'n_nodes in_features'] | Node feature matrix of shape (n_nodes, in_features). | required |
| edge_index | Int[Array, '2 n_edges'] | Edge indices of shape (2, n_edges). | required |
| edge_features | Float[Array, 'n_edges edge_features'] | Edge feature matrix of shape (n_edges, edge_features). | required |
| deterministic | bool | Whether to disable stochastic dropout. | True |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'n_nodes out_features'] | Updated node features of shape (n_nodes, out_features). |
GATv2Block¤
diffbio.core.gnn_components.GATv2Block
¤
GATv2Block(
    hidden_dim: int,
    num_heads: int,
    edge_features: int,
    dropout_rate: float,
    negative_slope: float = 0.2,
    *,
    rngs: Rngs,
)
Bases: Module
Full GNN block using GATv2 attention + LayerNorm + residual + FFN.
Architecture:
x -> GATv2Layer -> + -> LayerNorm -> FFN -> + -> LayerNorm -> out
|__________________|              |_________|
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension (both input and output). | required |
| num_heads | int | Number of attention heads. | required |
| edge_features | int | Edge feature dimension. | required |
| dropout_rate | float | Dropout rate. | required |
| negative_slope | float | Negative slope for LeakyReLU in GATv2 attention. | 0.2 |
| rngs | Rngs | Flax NNX random number generators. | required |
__call__
¤
__call__(
    node_features: Float[Array, "n_nodes hidden_dim"],
    edge_index: Int[Array, "2 n_edges"],
    edge_features: Float[Array, "n_edges edge_features"],
    *,
    deterministic: bool = True,
) -> Float[Array, "n_nodes hidden_dim"]
Apply the GATv2 block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node_features | Float[Array, 'n_nodes hidden_dim'] | Node features of shape (n_nodes, hidden_dim). | required |
| edge_index | Int[Array, '2 n_edges'] | Edge indices of shape (2, n_edges). | required |
| edge_features | Float[Array, 'n_edges edge_features'] | Edge features of shape (n_edges, edge_features). | required |
| deterministic | bool | If True, disable dropout. | True |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'n_nodes hidden_dim'] | Updated node features of shape (n_nodes, hidden_dim). |
Optimal Transport¤
SinkhornLayer¤
diffbio.core.optimal_transport.SinkhornLayer
¤
Bases: Module
Sinkhorn optimal transport layer (log-domain).
Computes the entropy-regularised optimal transport plan between two discrete marginal distributions a and b given a cost matrix C, by solving:
min_{P >= 0} <P, C> - epsilon * H(P)
s.t. P @ 1 = a, P^T @ 1 = b
The algorithm runs in log-domain for numerical stability:
f, g = 0, 0
for _ in range(num_iters):
    f = epsilon * log(a) - epsilon * logsumexp((-C + g[None, :]) / epsilon, axis=1)
    g = epsilon * log(b) - epsilon * logsumexp((-C + f[:, None]) / epsilon, axis=0)
P = exp((f[:, None] + g[None, :] - C) / epsilon)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| epsilon | float | Entropy regularisation strength (larger = smoother plan). | required |
| num_iters | int | Number of Sinkhorn iterations. | required |
| rngs | Rngs | Flax NNX random number generators (unused, kept for API consistency). | required |
__call__
¤
__call__(
    cost: Float[Array, "n m"],
    a: Float[Array, " n"],
    b: Float[Array, " m"],
) -> Float[Array, "n m"]
Compute the optimal transport plan.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cost | Float[Array, 'n m'] | Cost matrix of shape (n, m). | required |
| a | Float[Array, ' n'] | Source marginal distribution of shape (n,). | required |
| b | Float[Array, ' m'] | Target marginal distribution of shape (m,). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'n m'] | Transport plan of shape (n, m). |
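The log-domain iteration from the class description can be written out as a small NumPy function. This is a sketch under stated assumptions rather than the SinkhornLayer code: the function name sinkhorn_plan, the default values of epsilon and num_iters, and the local logsumexp helper are all chosen for the example.

```python
import numpy as np


def logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along one axis."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


def sinkhorn_plan(cost, a, b, epsilon=0.1, num_iters=200):
    """Entropy-regularised transport plan via log-domain Sinkhorn updates."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(num_iters):
        # Row update: enforce P @ 1 = a given the current g.
        f = epsilon * np.log(a) - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
        # Column update: enforce P^T @ 1 = b given the new f.
        g = epsilon * np.log(b) - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


cost = np.array([[0.0, 1.0], [1.0, 0.0]])
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
plan = sinkhorn_plan(cost, a, b)
```

With this cost matrix almost all mass stays on the diagonal, since moving mass off the diagonal costs 1 and epsilon is small. Larger epsilon values spread the plan more evenly.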
Type Annotations¤
DiffBio uses jaxtyping for array type annotations:
from jaxtyping import Array, Float, Int, PyTree


# Float array with shape annotations
def align(
    self,
    seq1: Float[Array, "len1 alphabet"],
    seq2: Float[Array, "len2 alphabet"],
) -> Float[Array, ""]:
    ...
Common type patterns:
| Annotation | Description |
|---|---|
| Float[Array, ""] | Scalar float |
| Float[Array, "n"] | 1D float array |
| Float[Array, "n m"] | 2D float array |
| Int[Array, "n"] | 1D integer array |
| PyTree | JAX pytree (dict, tuple, etc.) |