Core Base Classes¤

DiffBio operators inherit from Datarax's base classes, providing a consistent interface for composable, differentiable data processing.

OperatorModule¤

All DiffBio operators inherit from datarax.core.operator.OperatorModule, which provides:

  • Consistent apply() interface for data transformation
  • apply_batch() for batch processing
  • Integration with Flax NNX for learnable parameters

Interface¤

class OperatorModule:
    def apply(
        self,
        data: PyTree,
        state: PyTree,
        metadata: dict | None,
        random_params: Any = None,
        stats: dict | None = None,
    ) -> tuple[PyTree, PyTree, dict | None]:
        """Transform data through the operator.

        Args:
            data: Input data (typically dict of arrays)
            state: Per-element state
            metadata: Optional element metadata
            random_params: Random parameters for stochastic ops
            stats: Optional statistics dictionary

        Returns:
            Tuple of (transformed_data, updated_state, updated_metadata)
        """

Usage Pattern¤

from diffbio.operators.alignment import SmoothSmithWaterman, SmithWatermanConfig
from diffbio.operators.alignment import create_dna_scoring_matrix

# Create operator
config = SmithWatermanConfig(temperature=1.0)
scoring = create_dna_scoring_matrix(match=2.0, mismatch=-1.0)
operator = SmoothSmithWaterman(config, scoring_matrix=scoring)

# Prepare data
data = {"seq1": seq1_tensor, "seq2": seq2_tensor}
state = {}
metadata = None

# Apply operator
result_data, state, metadata = operator.apply(data, state, metadata)

OperatorConfig¤

Operator configurations subclass datarax.core.config.OperatorConfig as frozen dataclasses:

from dataclasses import dataclass
from datarax.core.config import OperatorConfig

@dataclass(frozen=True)
class MyOperatorConfig(OperatorConfig):
    """Configuration for MyOperator.

    Attributes:
        my_param: Description of parameter.
    """
    my_param: float = 1.0

DiffBio Operator Hierarchy¤

graph TB
    A["datarax.core.operator.OperatorModule"] --> T["TemperatureOperator"]
    T --> C["SmoothSmithWaterman"]
    T --> D["DifferentiablePileup"]
    T --> Q["DifferentiableQualityFilter"]
    A --> E["VariantClassifier"]
    A --> F["VariantCallingPipeline"]

    S["diffbio.core.soft_ops"] -.->|"used by"| T

    style A fill:#ede9fe,stroke:#7c3aed,color:#4c1d95
    style T fill:#fef3c7,stroke:#d97706,color:#92400e
    style S fill:#d1fae5,stroke:#059669,color:#065f46
    style C fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style D fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style Q fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style E fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style F fill:#e0e7ff,stroke:#4338ca,color:#312e81

The soft_ops module provides the differentiable primitives (sorting, comparisons, selection) used by TemperatureOperator and its subclasses. See the Soft Operations API reference for the full list.
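As a rough illustration of what a temperature-controlled differentiable comparison looks like, the sketch below relaxes a hard threshold test into a sigmoid. The helper name `soft_greater_than` and its signature are hypothetical and are not taken from diffbio.core.soft_ops.

```python
import jax
import jax.numpy as jnp


def soft_greater_than(x, threshold, temperature=1.0):
    """Smooth stand-in for (x > threshold); hypothetical helper, not the soft_ops API."""
    return jax.nn.sigmoid((x - threshold) / temperature)


quality = jnp.array([10.0, 20.0, 30.0])

# Low temperature approaches the hard 0/1 step; high temperature is smoother.
sharp = soft_greater_than(quality, 20.0, temperature=0.1)
smooth = soft_greater_than(quality, 20.0, temperature=10.0)

# Unlike a hard comparison, the threshold receives a non-zero gradient.
grad_threshold = jax.grad(
    lambda t: soft_greater_than(quality, t, temperature=10.0).sum()
)(20.0)
```

Raising the threshold lowers every soft indicator, so the gradient with respect to the threshold is negative here.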

Learnable Parameters¤

DiffBio uses Flax NNX's nnx.Param for learnable parameters:

import jax.numpy as jnp
from datarax.core.operator import OperatorModule
from flax import nnx

class MyOperator(OperatorModule):
    def __init__(self, config, rngs):
        super().__init__(config, rngs=rngs)

        # Learnable parameters
        self.temperature = nnx.Param(jnp.array(1.0))
        self.threshold = nnx.Param(jnp.array(20.0))

Accessing Parameters¤

# Get all parameters
params = nnx.state(operator, nnx.Param)

# Access specific parameter value
value = operator.temperature[...]

# Update parameter
operator.temperature[...] = new_value

Gradient Computation¤

from flax import nnx

def loss_fn(operator, data):
    result, _, _ = operator.apply(data, {}, None)
    return result["score"]

# Compute gradients w.r.t. the operator's nnx.Param variables
grads = nnx.grad(loss_fn)(operator, data)

Graph Utilities¤

compute_pairwise_distances¤

diffbio.core.graph_utils.compute_pairwise_distances ¤

compute_pairwise_distances(
    features: Array, metric: str = "euclidean"
) -> Array

Compute pairwise distance matrix between all samples.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| features | Array | Input feature matrix of shape (n_samples, n_features). | required |
| metric | str | Distance metric, either "euclidean" or "cosine". | 'euclidean' |

Returns:

| Type | Description |
| --- | --- |
| Array | Distance matrix of shape (n_samples, n_samples) where entry (i, j) is the distance from sample i to sample j. |
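A minimal JAX sketch of how such a distance matrix can be computed is shown below. The function name `pairwise_distances_sketch` is hypothetical, and the body is an assumption about one reasonable approach, not the DiffBio implementation.

```python
import jax.numpy as jnp


def pairwise_distances_sketch(features, metric="euclidean"):
    """Illustrative (n_samples, n_samples) distance matrix; not DiffBio's code."""
    if metric == "euclidean":
        sq_norms = jnp.sum(features**2, axis=1)
        # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j
        sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * features @ features.T
        # Clamp tiny negatives from round-off before the square root.
        return jnp.sqrt(jnp.maximum(sq_dists, 0.0))
    if metric == "cosine":
        norms = jnp.linalg.norm(features, axis=1, keepdims=True)
        unit = features / jnp.maximum(norms, 1e-12)
        return 1.0 - unit @ unit.T
    raise ValueError(f"unknown metric: {metric}")


points = jnp.array([[1.0, 0.0], [0.0, 1.0], [3.0, 4.0]])
euclid = pairwise_distances_sketch(points)
cosine = pairwise_distances_sketch(points, metric="cosine")
```

Both results are symmetric 3 x 3 matrices with a zero diagonal for these inputs.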

compute_knn_graph¤

diffbio.core.graph_utils.compute_knn_graph ¤

compute_knn_graph(
    distances: Array, k: int
) -> tuple[Array, Array]

Build a k-nearest-neighbor graph from a dense distance matrix.

For each node the k closest neighbours (by distance) are selected. Self-connections are assumed to already be masked out by setting the diagonal to a large value before calling this function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| distances | Array | Dense distance matrix of shape (n, n). The diagonal should contain large sentinel values (e.g. DISTANCE_MASK_SENTINEL) so that self-loops are never selected. | required |
| k | int | Number of nearest neighbours per node. Clipped to n - 1 when larger than the number of samples minus one. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array] | A tuple (edge_indices, edge_weights) where edge_indices has shape (n * k_eff, 2) with each row [source, target], and edge_weights has shape (n * k_eff,) containing the corresponding distances. k_eff = min(k, n - 1). |
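The input and output shapes described above can be illustrated with a small sketch built on `jax.lax.top_k`. The function name `knn_graph_sketch` and the `SENTINEL` constant are hypothetical stand-ins; the actual DiffBio implementation may differ.

```python
import jax
import jax.numpy as jnp


def knn_graph_sketch(distances, k):
    """Illustrative k-NN edge list from a masked distance matrix; not DiffBio's code."""
    n = distances.shape[0]
    k_eff = min(k, n - 1)
    # top_k returns the largest entries, so negate to pick the smallest distances.
    neg_dists, neighbours = jax.lax.top_k(-distances, k_eff)
    sources = jnp.repeat(jnp.arange(n), k_eff)
    edge_indices = jnp.stack([sources, neighbours.reshape(-1)], axis=1)
    edge_weights = -neg_dists.reshape(-1)
    return edge_indices, edge_weights


SENTINEL = 1e10  # assumed stand-in for DISTANCE_MASK_SENTINEL
coords = jnp.array([0.0, 1.0, 3.0, 7.0])
dists = jnp.abs(coords[:, None] - coords[None, :])
# Mask the diagonal so a node never selects itself.
dists = jnp.where(jnp.eye(4, dtype=bool), SENTINEL, dists)

edge_indices, edge_weights = knn_graph_sketch(dists, k=2)
clipped_indices, _ = knn_graph_sketch(dists, k=10)  # k is clipped to n - 1 = 3
```

With four nodes and k=2 the edge list has 8 rows; asking for k=10 clips to three neighbours per node and yields 12 rows.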

compute_fuzzy_membership¤

diffbio.core.graph_utils.compute_fuzzy_membership ¤

compute_fuzzy_membership(
    distances: Array, k: int, softness: float = 0.1
) -> Array

Compute fuzzy set membership using a Gaussian kernel with local bandwidth.

The bandwidth (sigma) for each sample is set to the distance to its k-th nearest neighbour, making the kernel adapt to local density. The diagonal of the output is forced to zero (no self-similarity).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| distances | Array | Dense distance matrix of shape (n, n). The diagonal should contain large sentinel values so that self-distances are excluded from the bandwidth computation. | required |
| k | int | Number of neighbours used to determine the local bandwidth. Clipped to n - 1 when larger. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Fuzzy membership matrix of shape (n, n) with values in [0, 1]. |
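The local-bandwidth idea can be sketched as follows. This is an assumption-laden illustration: the exact kernel form exp(-d^2 / (2 sigma^2)) is a guess at a standard Gaussian kernel, the name `fuzzy_membership_sketch` is hypothetical, and the `softness` argument is omitted because its role is not described here.

```python
import jax.numpy as jnp


def fuzzy_membership_sketch(distances, k):
    """Illustrative Gaussian membership with per-row bandwidth; not DiffBio's code."""
    n = distances.shape[0]
    k_eff = min(k, n - 1)
    # sigma_i = distance from sample i to its k-th nearest neighbour.
    sigma = jnp.sort(distances, axis=1)[:, k_eff - 1]
    membership = jnp.exp(-(distances**2) / (2.0 * sigma[:, None] ** 2))
    # No self-similarity: force the diagonal to zero.
    return membership * (1.0 - jnp.eye(n))


SENTINEL = 1e10  # assumed large value masking the diagonal
coords = jnp.array([0.0, 1.0, 3.0, 7.0])
dists = jnp.abs(coords[:, None] - coords[None, :])
dists = jnp.where(jnp.eye(4, dtype=bool), SENTINEL, dists)

membership = fuzzy_membership_sketch(dists, k=2)
```

Because each row uses its own bandwidth, the result is generally not symmetric: here the entry (0, 1) differs from the entry (1, 0).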

symmetrize_graph¤

diffbio.core.graph_utils.symmetrize_graph ¤

symmetrize_graph(adjacency: Array) -> Array

Symmetrize a directed adjacency matrix via fuzzy set union.

Applies the probabilistic (fuzzy) union: p_sym = p + p^T - p * p^T

This ensures the output is symmetric and, when inputs are in [0, 1], the outputs remain in [0, 1].

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| adjacency | Array | Directed adjacency / membership matrix of shape (n, n). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Symmetric adjacency matrix of shape (n, n). |
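The union formula can be rewritten as p + q - p * q = 1 - (1 - p) * (1 - q), which makes the [0, 1] bound explicit: a product of two values in [0, 1] stays in [0, 1]. A small worked example, using a hypothetical helper name `symmetrize_sketch`:

```python
import jax.numpy as jnp


def symmetrize_sketch(adjacency):
    """Probabilistic (fuzzy) union of a matrix with its transpose."""
    transposed = adjacency.T
    return adjacency + transposed - adjacency * transposed


directed = jnp.array([[0.0, 0.8, 0.0],
                      [0.5, 0.0, 0.2],
                      [0.0, 0.0, 0.0]])
symmetric = symmetrize_sketch(directed)
```

The pair (0, 1) has memberships 0.8 and 0.5, so the symmetric value is 0.8 + 0.5 - 0.4 = 0.9. The pair (1, 2) has only one non-zero direction, so it keeps the value 0.2.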

GNN Components¤

GraphAttentionLayer¤

diffbio.core.gnn_components.GraphAttentionLayer ¤

GraphAttentionLayer(
    in_features: int,
    out_features: int,
    num_heads: int,
    edge_features: int,
    dropout_rate: float,
    *,
    rngs: Rngs,
)

Bases: Module

Multi-head graph attention layer for message passing.

Computes attention-weighted message aggregation over graph edges. Each attention head independently computes query/key/value projections, adds an edge-feature bias to attention scores, normalizes via segment-softmax, and aggregates weighted values per target node.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_features | int | Input node feature dimension. | required |
| out_features | int | Output feature dimension (must be divisible by num_heads). | required |
| num_heads | int | Number of parallel attention heads. | required |
| edge_features | int | Edge feature dimension. | required |
| dropout_rate | float | Dropout rate applied to attention weights. | required |
| rngs | Rngs | Flax NNX random number generators. | required |

__call__ ¤

__call__(
    node_features: Float[Array, "n_nodes in_features"],
    edge_index: Int[Array, "2 n_edges"],
    edge_features: Float[Array, "n_edges edge_features"],
    *,
    deterministic: bool = True,
) -> Float[Array, "n_nodes out_features"]

Run one graph-attention update step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_features | Float[Array, 'n_nodes in_features'] | Node feature matrix of shape (n_nodes, in_features). | required |
| edge_index | Int[Array, '2 n_edges'] | Edge indices (source, target) of shape (2, n_edges). | required |
| edge_features | Float[Array, 'n_edges edge_features'] | Edge feature matrix of shape (n_edges, edge_features). | required |
| deterministic | bool | Whether to disable stochastic dropout. | True |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, 'n_nodes out_features'] | Updated node features of shape (n_nodes, out_features). |
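The segment-softmax normalisation and per-target aggregation mentioned in the layer description can be sketched with `jax.ops.segment_max` and `jax.ops.segment_sum`. The helper name `segment_softmax_sketch` and the toy edge scores are illustrative assumptions; the layer's own projections and edge bias are left out.

```python
import jax
import jax.numpy as jnp


def segment_softmax_sketch(scores, targets, n_nodes):
    """Softmax over edges that share a target node; illustrative only."""
    # Subtract each target's max score for numerical stability.
    seg_max = jax.ops.segment_max(scores, targets, num_segments=n_nodes)
    exp_scores = jnp.exp(scores - seg_max[targets])
    seg_sum = jax.ops.segment_sum(exp_scores, targets, num_segments=n_nodes)
    return exp_scores / seg_sum[targets]


# Four edges (source -> target): 0->2, 1->2, 0->1, 2->1
edge_index = jnp.array([[0, 1, 0, 2],
                        [2, 2, 1, 1]])
scores = jnp.log(jnp.array([2.0, 2.0, 1.0, 3.0]))
values = jnp.array([[1.0], [3.0], [10.0], [20.0]])  # one message per edge

weights = segment_softmax_sketch(scores, edge_index[1], n_nodes=3)
# Weighted messages are summed per target node.
aggregated = jax.ops.segment_sum(
    weights[:, None] * values, edge_index[1], num_segments=3
)
```

The weights sum to one over the incoming edges of each target node, and a node without incoming edges (node 0 here) receives a zero vector.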

GraphAttentionBlock¤

diffbio.core.gnn_components.GraphAttentionBlock ¤

GraphAttentionBlock(
    hidden_dim: int,
    num_heads: int,
    edge_features: int,
    dropout_rate: float,
    *,
    rngs: Rngs,
)

Bases: Module

Full GNN block: graph attention + LayerNorm + residual + feedforward.

Combines a GraphAttentionLayer with pre-norm residual connections and a two-layer feedforward network (4x expansion), following the standard Transformer block pattern adapted for graphs.

Architecture:

x -> GraphAttentionLayer -> + -> LayerNorm -> FFN -> + -> LayerNorm -> out
|___________________________|              |_________|

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_dim | int | Hidden dimension (both input and output). | required |
| num_heads | int | Number of attention heads. | required |
| edge_features | int | Edge feature dimension. | required |
| dropout_rate | float | Dropout rate. | required |
| rngs | Rngs | Flax NNX random number generators. | required |

__call__ ¤

__call__(
    node_features: Float[Array, "n_nodes hidden_dim"],
    edge_index: Int[Array, "2 n_edges"],
    edge_features: Float[Array, "n_edges edge_features"],
    *,
    deterministic: bool = True,
) -> Float[Array, "n_nodes hidden_dim"]

Apply the GNN block.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_features | Float[Array, 'n_nodes hidden_dim'] | Node features of shape (n_nodes, hidden_dim). | required |
| edge_index | Int[Array, '2 n_edges'] | Edge indices of shape (2, n_edges). | required |
| edge_features | Float[Array, 'n_edges edge_features'] | Edge features of shape (n_edges, edge_features). | required |
| deterministic | bool | If True, disable dropout. | True |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, 'n_nodes hidden_dim'] | Updated node features of shape (n_nodes, hidden_dim). |

GATv2Layer¤

diffbio.core.gnn_components.GATv2Layer ¤

GATv2Layer(
    in_features: int,
    out_features: int,
    num_heads: int,
    edge_features: int,
    dropout_rate: float,
    negative_slope: float = 0.2,
    *,
    rngs: Rngs,
)

Bases: Module

GATv2 multi-head graph attention layer.

Unlike the original GAT (GraphAttentionLayer), GATv2 applies LeakyReLU before computing the attention scalar, which makes the attention function strictly more expressive (it can represent any monotonic scoring function over concatenated source/target features).

GATv2 attention:

e_{ij} = a^T * LeakyReLU(W_l * h_i + W_r * h_j + edge_bias)

This is the key difference from GAT, where the nonlinearity is applied after the attention dot product.

Reference: Brody, Alon, Yahav. "How Attentive are Graph Attention Networks?" (ICLR 2022).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_features | int | Input node feature dimension. | required |
| out_features | int | Output feature dimension (must be divisible by num_heads). | required |
| num_heads | int | Number of parallel attention heads. | required |
| edge_features | int | Edge feature dimension. | required |
| dropout_rate | float | Dropout rate applied to attention weights. | required |
| negative_slope | float | Negative slope for LeakyReLU. | 0.2 |
| rngs | Rngs | Flax NNX random number generators. | required |

__call__ ¤

__call__(
    node_features: Float[Array, "n_nodes in_features"],
    edge_index: Int[Array, "2 n_edges"],
    edge_features: Float[Array, "n_edges edge_features"],
    *,
    deterministic: bool = True,
) -> Float[Array, "n_nodes out_features"]

Run one GATv2 attention update step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_features | Float[Array, 'n_nodes in_features'] | Node feature matrix of shape (n_nodes, in_features). | required |
| edge_index | Int[Array, '2 n_edges'] | Edge indices (source, target) of shape (2, n_edges). | required |
| edge_features | Float[Array, 'n_edges edge_features'] | Edge feature matrix of shape (n_edges, edge_features). | required |
| deterministic | bool | Whether to disable stochastic dropout. | True |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, 'n_nodes out_features'] | Updated node features of shape (n_nodes, out_features). |
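The GATv2 score e_{ij} = a^T * LeakyReLU(W_l * h_i + W_r * h_j + edge_bias) can be sketched for a single head as below. The function name, the argument names, the shape of `edge_bias`, and the choice of i as the target node and j as the source node are all assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp


def gatv2_scores_sketch(h, edge_index, w_l, w_r, a, edge_bias, negative_slope=0.2):
    """Single-head GATv2 edge scores; parameter names are illustrative."""
    # Assumed convention for this sketch: i is the target node, j the source node.
    src, dst = edge_index
    pre_activation = h[dst] @ w_l + h[src] @ w_r + edge_bias
    # GATv2: nonlinearity first, then the projection onto the attention vector a.
    return jax.nn.leaky_relu(pre_activation, negative_slope) @ a


h = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # three nodes, two features
edge_index = jnp.array([[0, 1],   # sources
                        [2, 2]])  # targets
w_l = jnp.eye(2)
w_r = -2.0 * jnp.eye(2)
a = jnp.array([1.0, 2.0])
edge_bias = jnp.zeros((2, 2))  # assumed shape (n_edges, head_dim)

scores = gatv2_scores_sketch(h, edge_index, w_l, w_r, a, edge_bias)
```

For the first edge the pre-activation is [-1, 1], LeakyReLU maps it to [-0.2, 1], and the projection onto a gives 1.8. The second edge gives 0.6. In a full layer these scores would then be normalised per target node.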

GATv2Block¤

diffbio.core.gnn_components.GATv2Block ¤

GATv2Block(
    hidden_dim: int,
    num_heads: int,
    edge_features: int,
    dropout_rate: float,
    negative_slope: float = 0.2,
    *,
    rngs: Rngs,
)

Bases: Module

Full GNN block using GATv2 attention + LayerNorm + residual + FFN.

Architecture:

x -> GATv2Layer -> + -> LayerNorm -> FFN -> + -> LayerNorm -> out
|__________________|              |_________|

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_dim | int | Hidden dimension (both input and output). | required |
| num_heads | int | Number of attention heads. | required |
| edge_features | int | Edge feature dimension. | required |
| dropout_rate | float | Dropout rate. | required |
| negative_slope | float | Negative slope for LeakyReLU in GATv2 attention. | 0.2 |
| rngs | Rngs | Flax NNX random number generators. | required |

__call__ ¤

__call__(
    node_features: Float[Array, "n_nodes hidden_dim"],
    edge_index: Int[Array, "2 n_edges"],
    edge_features: Float[Array, "n_edges edge_features"],
    *,
    deterministic: bool = True,
) -> Float[Array, "n_nodes hidden_dim"]

Apply the GATv2 block.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_features | Float[Array, 'n_nodes hidden_dim'] | Node features of shape (n_nodes, hidden_dim). | required |
| edge_index | Int[Array, '2 n_edges'] | Edge indices of shape (2, n_edges). | required |
| edge_features | Float[Array, 'n_edges edge_features'] | Edge features of shape (n_edges, edge_features). | required |
| deterministic | bool | If True, disable dropout. | True |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, 'n_nodes hidden_dim'] | Updated node features of shape (n_nodes, hidden_dim). |

Optimal Transport¤

SinkhornLayer¤

diffbio.core.optimal_transport.SinkhornLayer ¤

SinkhornLayer(
    epsilon: float, num_iters: int, *, rngs: Rngs
)

Bases: Module

Sinkhorn optimal transport layer (log-domain).

Computes the entropy-regularised optimal transport plan between two discrete marginal distributions a and b given a cost matrix C, by solving:

min_{P >= 0}  <P, C> - epsilon * H(P)
s.t.  P @ 1 = a,   P^T @ 1 = b

The algorithm runs in log-domain for numerical stability, with f of shape (n,) and g of shape (m,):

f, g = zeros(n), zeros(m)
for _ in range(num_iters):
    f = epsilon * log(a) - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
    g = epsilon * log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
P = exp((f[:, None] + g[None, :] - C) / epsilon)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epsilon | float | Entropy regularisation strength (larger = smoother plan). | required |
| num_iters | int | Number of Sinkhorn iterations. | required |
| rngs | Rngs | Flax NNX random number generators (unused, kept for API consistency). | required |

__call__ ¤

__call__(
    cost: Float[Array, "n m"],
    a: Float[Array, " n"],
    b: Float[Array, " m"],
) -> Float[Array, "n m"]

Compute the optimal transport plan.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cost | Float[Array, 'n m'] | Cost matrix of shape (n, m). | required |
| a | Float[Array, ' n'] | Source marginal distribution of shape (n,), must sum to 1. | required |
| b | Float[Array, ' m'] | Target marginal distribution of shape (m,), must sum to 1. | required |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, 'n m'] | Transport plan of shape (n, m) satisfying (approximately) P @ 1 = a and P^T @ 1 = b. |
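A standalone sketch of the log-domain iterations, written as a plain function so the marginal constraints can be checked directly. The name `sinkhorn_log_sketch` and the toy cost matrix are illustrative assumptions; this is not the SinkhornLayer source.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_log_sketch(cost, a, b, epsilon, num_iters):
    """Log-domain Sinkhorn iterations; illustrative, not the SinkhornLayer code."""
    f = jnp.zeros_like(a)  # dual potential for the rows, shape (n,)
    g = jnp.zeros_like(b)  # dual potential for the columns, shape (m,)
    for _ in range(num_iters):
        f = epsilon * jnp.log(a) - epsilon * logsumexp(
            (g[None, :] - cost) / epsilon, axis=1
        )
        g = epsilon * jnp.log(b) - epsilon * logsumexp(
            (f[:, None] - cost) / epsilon, axis=0
        )
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


cost = jnp.array([[0.0, 1.0, 2.0],
                  [1.0, 0.0, 1.0]])
a = jnp.array([0.5, 0.5])
b = jnp.array([0.25, 0.5, 0.25])

plan = sinkhorn_log_sketch(cost, a, b, epsilon=0.5, num_iters=200)
row_sums = plan.sum(axis=1)  # approximately a
col_sums = plan.sum(axis=0)  # approximately b
```

Because the last update is the one for g, the column sums match b up to floating-point error, while the row sums match a only once the iterations have converged.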

Type Annotations¤

DiffBio uses jaxtyping for array type annotations:

from jaxtyping import Array, Float, Int, PyTree

# Float array with shape annotations
def align(
    self,
    seq1: Float[Array, "len1 alphabet"],
    seq2: Float[Array, "len2 alphabet"],
) -> Float[Array, ""]:
    ...

Common type patterns:

| Annotation | Description |
| --- | --- |
| Float[Array, ""] | Scalar float |
| Float[Array, "n"] | 1D float array |
| Float[Array, "n m"] | 2D float array |
| Int[Array, "n"] | 1D integer array |
| PyTree | JAX pytree (dict, tuple, etc.) |