Assembly Operators API

Differentiable operators for genome assembly using graph neural networks and VAE-based binning.

GNNAssemblyNavigator

diffbio.operators.assembly.gnn_assembly.GNNAssemblyNavigator

GNNAssemblyNavigator(
    config: GNNAssemblyNavigatorConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: GraphOperator

Graph Neural Network for assembly graph traversal.

This operator uses message passing with graph attention to update node embeddings and predict edge traversal probabilities for differentiable assembly.

Algorithm:

  1. Project input node features to the hidden dimension.
  2. Apply multiple GNN layers with graph attention.
  3. Compute edge scores from source/target node embeddings.
  4. Apply a sigmoid to obtain traversal probabilities.
  5. Compute path confidence from the edge scores.
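Step 2 of this algorithm can be illustrated with one graph-attention update. The sketch below is a generic single-head GAT-style layer in plain JAX, not diffbio's implementation: the parameter names `W`, `a_src` and `a_dst`, the LeakyReLU slope of 0.2 and the toy graph are assumptions. Each target node normalizes attention over its incoming edges and sums the weighted source messages.

```python
import jax
import jax.numpy as jnp


def gat_layer(h, edge_index, W, a_src, a_dst):
    """One single-head graph-attention update (illustrative sketch, not diffbio code).

    h: (n_nodes, d_in) node features
    edge_index: (2, n_edges) integer array of (source, target) pairs
    W: (d_in, d_out) projection; a_src, a_dst: (d_out,) attention vectors
    """
    src, dst = edge_index
    n_nodes = h.shape[0]
    z = h @ W  # project node features to the hidden dimension

    # Unnormalized attention logit for each edge (source -> target).
    logits = jax.nn.leaky_relu(z[src] @ a_src + z[dst] @ a_dst, negative_slope=0.2)

    # Numerically stable softmax over the incoming edges of each target node.
    max_per_dst = jax.ops.segment_max(logits, dst, num_segments=n_nodes)
    weights = jnp.exp(logits - max_per_dst[dst])
    denom = jax.ops.segment_sum(weights, dst, num_segments=n_nodes)
    alpha = weights / denom[dst]

    # Attention-weighted sum of source messages; nodes without in-edges get zeros.
    h_new = jax.ops.segment_sum(alpha[:, None] * z[src], dst, num_segments=n_nodes)
    return h_new, alpha


# Toy graph: 0->1, 1->2, 2->3, 0->2 (node 0 has no incoming edges).
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
h = jax.random.normal(k1, (4, 8))
edge_index = jnp.array([[0, 1, 2, 0], [1, 2, 3, 2]])
W = jax.random.normal(k2, (8, 16))
a_src = jax.random.normal(k3, (16,))
a_dst = jax.random.normal(k4, (16,))
h_new, alpha = gat_layer(h, edge_index, W, a_src, a_dst)
```

A multi-head version (as suggested by `num_heads`) would run several such layers with separate parameters and concatenate or average their outputs; how diffbio combines heads is not stated on this page.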

Inherits from GraphOperator to get:

  • scatter_aggregate() for message aggregation utilities
  • global_pool() for graph-level pooling
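The two inherited utilities correspond to standard segment operations. As a hedged sketch, assuming `scatter_aggregate()` reduces per-edge messages into their target nodes and `global_pool()` averages node embeddings into one graph vector (their exact signatures are not documented here), equivalent JAX code looks like this; `scatter_mean` and `mean_pool` are hypothetical names.

```python
import jax.numpy as jnp
from jax import ops


def scatter_mean(messages, targets, n_nodes):
    """Average per-edge messages into their target nodes (hypothetical helper)."""
    summed = ops.segment_sum(messages, targets, num_segments=n_nodes)
    counts = ops.segment_sum(
        jnp.ones_like(targets, dtype=messages.dtype), targets, num_segments=n_nodes
    )
    # Avoid division by zero for nodes that receive no messages.
    return summed / jnp.maximum(counts, 1.0)[:, None]


def mean_pool(node_embeddings):
    """Graph-level embedding as the mean over nodes (hypothetical helper)."""
    return node_embeddings.mean(axis=0)


messages = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = jnp.array([1, 1, 2])  # edges 0 and 1 point at node 1, edge 2 at node 2
node_agg = scatter_mean(messages, targets, n_nodes=3)
graph_vec = mean_pool(node_agg)
```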

Uses temperature-controlled smoothing: the _temperature property supplies the temperature for the traversal sigmoid.
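Steps 3 and 4 of the algorithm, together with the temperature control, can be sketched as follows. The sketch assumes a dot-product score between source and target embeddings and divides that score by the temperature before the sigmoid; both are illustrative assumptions, not the documented diffbio scoring function, and `traversal_probs` here is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def traversal_probs(node_embeddings, edge_index, temperature=1.0):
    """Edge scores from endpoint embeddings, then a temperature-scaled sigmoid."""
    src, dst = edge_index
    # Assumed score: dot product of source and target embeddings.
    edge_scores = jnp.sum(node_embeddings[src] * node_embeddings[dst], axis=-1)
    # Lower temperature pushes probabilities towards 0 or 1 (sharper);
    # higher temperature pulls them towards 0.5 (smoother).
    return edge_scores, jax.nn.sigmoid(edge_scores / temperature)


emb = jnp.array([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]])
edges = jnp.array([[0, 0], [1, 2]])  # edges 0->1 and 0->2
scores, p_smooth = traversal_probs(emb, edges, temperature=10.0)
_, p_sharp = traversal_probs(emb, edges, temperature=0.1)
```

Because the sigmoid stays differentiable at any positive temperature, gradients reach every edge score; the temperature only trades off how decisive the probabilities are.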

Parameters:

  • config (GNNAssemblyNavigatorConfig, required): GNNAssemblyNavigatorConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators used for initialization.
  • name (str | None, default None): Optional operator name.

Example
from flax import nnx
config = GNNAssemblyNavigatorConfig(hidden_dim=128)
navigator = GNNAssemblyNavigator(config, rngs=nnx.Rngs(42))
# With the default config: nodes (n_nodes, 64), edges (2, n_edges), edge_attr (n_edges, 8)
data = {"node_features": nodes, "edge_index": edges, "edge_features": edge_attr}
result, state, meta = navigator.apply(data, {}, None)

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply GNN assembly navigation.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "node_features": node features, shape (n_nodes, node_features)
      - "edge_index": edge indices, shape (2, n_edges)
      - "edge_features": edge features, shape (n_edges, edge_features)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used.
  • stats (dict[str, Any] | None, default None): Not used.
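Because `apply` only documents the expected shapes, a small pre-flight check can catch mismatches before the operator runs. The helper below, `check_assembly_graph`, is hypothetical (not part of diffbio); it assumes the feature widths must match `GNNAssemblyNavigatorConfig.node_features` and `edge_features`, and that `edge_index` holds integer node indices.

```python
import jax.numpy as jnp


def check_assembly_graph(data, node_dim, edge_dim):
    """Validate a navigator input dict against expected feature widths (hypothetical)."""
    nodes = data["node_features"]
    edge_index = data["edge_index"]
    edge_feats = data["edge_features"]

    if nodes.ndim != 2 or nodes.shape[1] != node_dim:
        raise ValueError(f"node_features must be (n_nodes, {node_dim}), got {nodes.shape}")
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError(f"edge_index must be (2, n_edges), got {edge_index.shape}")
    n_nodes, n_edges = nodes.shape[0], edge_index.shape[1]
    if not jnp.issubdtype(edge_index.dtype, jnp.integer):
        raise TypeError("edge_index must contain integer node indices")
    if n_edges > 0 and (int(edge_index.min()) < 0 or int(edge_index.max()) >= n_nodes):
        raise ValueError("edge_index refers to a node outside [0, n_nodes)")
    if edge_feats.shape != (n_edges, edge_dim):
        raise ValueError(f"edge_features must be ({n_edges}, {edge_dim}), got {edge_feats.shape}")
    return n_nodes, n_edges


good = {
    "node_features": jnp.zeros((5, 64)),
    "edge_index": jnp.array([[0, 1], [1, 4]]),
    "edge_features": jnp.zeros((2, 8)),
}
print(check_assembly_graph(good, node_dim=64, edge_dim=8))

bad = dict(good, edge_index=jnp.array([[0, 1], [1, 5]]))  # node 5 does not exist
try:
    check_assembly_graph(bad, node_dim=64, edge_dim=8)
except ValueError as err:
    print("rejected:", err)
```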

Returns:

  • tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata).
      - transformed_data contains:
          - "node_features": original node features
          - "edge_index": original edge indices
          - "edge_features": original edge features
          - "node_embeddings": updated node embeddings
          - "edge_scores": score for each edge
          - "traversal_probs": sigmoid traversal probability for each edge
          - "path_confidence": confidence score for paths
      - state is passed through unchanged.
      - metadata is passed through unchanged.
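A common way to consume `traversal_probs` is to pick, for every node, its most probable outgoing edge. The sketch below does this with segment operations over `edge_index`; `best_successor` is a hypothetical post-processing helper, not something `apply` returns, and it assumes row 0 of `edge_index` holds source nodes.

```python
import jax.numpy as jnp
from jax import ops


def best_successor(edge_index, traversal_probs, n_nodes):
    """Per node, the target of its highest-probability outgoing edge (-1 if none)."""
    src, dst = edge_index
    n_edges = src.shape[0]
    best_prob = ops.segment_max(traversal_probs, src, num_segments=n_nodes)
    # Mark edges that attain their source node's maximum probability.
    is_best = traversal_probs == best_prob[src]
    # Keep the largest edge id among ties so each node gets exactly one edge.
    edge_ids = jnp.where(is_best, jnp.arange(n_edges), -1)
    best_edge = ops.segment_max(edge_ids, src, num_segments=n_nodes)
    # Count outgoing edges explicitly so empty segments are handled safely.
    has_out = ops.segment_sum(jnp.ones_like(src), src, num_segments=n_nodes) > 0
    return jnp.where(has_out, dst[jnp.clip(best_edge, 0, n_edges - 1)], -1)


edge_index = jnp.array([[0, 0, 1, 2], [1, 2, 3, 3]])
probs = jnp.array([0.2, 0.9, 0.7, 0.4])
succ = best_successor(edge_index, probs, n_nodes=4)  # node 3 has no outgoing edge
```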

GNNAssemblyNavigatorConfig

diffbio.operators.assembly.gnn_assembly.GNNAssemblyNavigatorConfig dataclass

GNNAssemblyNavigatorConfig(
    temperature: float = 1.0,
    learnable_temperature: bool = False,
    node_features: int = 64,
    hidden_dim: int = 128,
    num_layers: int = 3,
    num_heads: int = 4,
    edge_features: int = 8,
    dropout_rate: float = 0.1,
)

Bases: TemperatureConfig

Configuration for GNNAssemblyNavigator.

Attributes:

  • node_features (int): Dimension of input node features.
  • hidden_dim (int): Hidden dimension for GNN layers.
  • num_layers (int): Number of GNN layers.
  • num_heads (int): Number of attention heads.
  • edge_features (int): Dimension of edge features.
  • dropout_rate (float): Dropout rate for regularization.
  • temperature (float): Temperature of the temperature-controlled traversal sigmoid.

learnable_temperature class-attribute instance-attribute

learnable_temperature: bool = False

DifferentiableMetagenomicBinner

diffbio.operators.assembly.metagenomic_binning.DifferentiableMetagenomicBinner

DifferentiableMetagenomicBinner(
    config: MetagenomicBinnerConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)

Bases: TemperatureOperator, EncoderDecoderOperator

VAMB-style differentiable metagenomic binning.

This operator implements a Variational Autoencoder for metagenomic binning, encoding tetranucleotide frequencies (TNF) and abundance profiles into a shared latent space where contigs from the same genome cluster together.

The approach is fully differentiable, enabling:

  • end-to-end optimization with downstream tasks
  • soft cluster assignments via temperature-controlled softmax
  • integration with neural abundance estimation

Input data structure:
  • tnf: Float[Array, "n_contigs n_tnf"] - Tetranucleotide frequencies
  • abundance: Float[Array, "n_contigs n_samples"] - Sample abundances
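The default `n_tnf_features=136` matches the number of tetranucleotides once each 4-mer is merged with its reverse complement: there are 256 4-mers, 16 of which are their own reverse complement, so (256 - 16) / 2 + 16 = 136. A minimal sketch of computing such a TNF vector for one contig follows; the lexicographic canonicalization, skipping of non-ACGT windows and plain frequency normalization are assumptions, and diffbio's own TNF preprocessing (if any) is not documented on this page.

```python
from itertools import product

import jax.numpy as jnp

_COMPLEMENT = str.maketrans("ACGT", "TGCA")


def revcomp(kmer):
    """Reverse complement of a DNA string."""
    return kmer.translate(_COMPLEMENT)[::-1]


# Canonical 4-mers: the lexicographically smaller of each k-mer and its reverse complement.
CANONICAL_4MERS = sorted(
    {min(k, revcomp(k)) for k in ("".join(p) for p in product("ACGT", repeat=4))}
)
INDEX = {k: i for i, k in enumerate(CANONICAL_4MERS)}


def tnf(sequence):
    """Tetranucleotide frequency vector of length 136; windows with non-ACGT bases are skipped."""
    counts = [0] * len(CANONICAL_4MERS)
    seq = sequence.upper()
    for i in range(len(seq) - 3):
        kmer = seq[i:i + 4]
        if set(kmer) <= set("ACGT"):
            counts[INDEX[min(kmer, revcomp(kmer))]] += 1
    total = sum(counts)
    return jnp.array(counts, dtype=jnp.float32) / max(total, 1)


vec = tnf("ACGTACGTTTGACN")  # last window "GACN" contains N and is skipped
```

Stacking one such vector per contig gives the `tnf` input of shape (n_contigs, 136).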

Output data structure (adds):

  • latent_z: Float[Array, "n_contigs latent_dim"] - Latent representations
  • latent_mu: Float[Array, "n_contigs latent_dim"] - Latent means
  • latent_logvar: Float[Array, "n_contigs latent_dim"] - Latent log variance
  • cluster_assignments: Float[Array, "n_contigs n_clusters"] - Soft bins
  • reconstructed_tnf: Float[Array, "n_contigs n_tnf"] - Reconstructed TNF
  • reconstructed_abundance: Float[Array, "n_contigs n_samples"] - Reconstructed abundance

Example
from flax import nnx
config = MetagenomicBinnerConfig(n_abundance_features=5, n_clusters=50)
binner = DifferentiableMetagenomicBinner(config, rngs=nnx.Rngs(42))
# data = {"tnf": (n_contigs, 136) array, "abundance": (n_contigs, 5) array}
result, state, meta = binner.apply(data, {}, None)
bins = result["cluster_assignments"].argmax(axis=-1)

Parameters:

  • config (MetagenomicBinnerConfig, required): Binner configuration.
  • rngs (Rngs, required): Random number generators.
  • name (str | None, default None): Optional name for the operator.

apply

apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]

Apply metagenomic binning.

Parameters:

  • data (dict[str, Array], required): Input data containing:
      - tnf: Float[Array, "n_contigs n_tnf"]
      - abundance: Float[Array, "n_contigs n_samples"]
  • state (dict[str, Any], required): Element state (passed through).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through).
  • random_params (Any, default None): Random parameters.
  • stats (dict[str, Any] | None, default None): Optional statistics dict.

Returns:

  • tuple[dict[str, Array], dict[str, Any], dict[str, Any] | None]: Tuple of (output_data, state, metadata).

encode

encode(
    x: Float[Array, "batch input_dim"],
) -> tuple[
    Float[Array, "batch latent"],
    Float[Array, "batch latent"],
]

Encode input to latent distribution.

Parameters:

  • x (Float[Array, 'batch input_dim'], required): Concatenated TNF and abundance features.

Returns:

  • tuple[Float[Array, 'batch latent'], Float[Array, 'batch latent']]: Tuple of (mu, logvar).
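In a VAE, the encoder outputs the parameters of a diagonal Gaussian, and a latent sample is drawn with the reparameterization trick, z = mu + exp(0.5 * logvar) * eps, so that gradients flow through `mu` and `logvar`. The sketch below shows that formula and the concatenated input `encode` expects; it assumes `latent_z` is produced this way, which this page does not state explicitly, and `reparameterize` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) differentiably w.r.t. mu and logvar."""
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * logvar) * eps


n_contigs, n_tnf, n_samples, latent_dim = 4, 136, 10, 32
tnf = jnp.full((n_contigs, n_tnf), 1.0 / n_tnf)
abundance = jnp.ones((n_contigs, n_samples))
x = jnp.concatenate([tnf, abundance], axis=-1)  # encode() input: (batch, 136 + n_samples)

mu = jnp.zeros((n_contigs, latent_dim))
logvar = jnp.zeros((n_contigs, latent_dim))  # unit variance
z = reparameterize(jax.random.PRNGKey(0), mu, logvar)

# d(sum z)/d(mu) is all ones: the sample is differentiable w.r.t. the encoder output.
grad_mu = jax.grad(lambda m: reparameterize(jax.random.PRNGKey(0), m, logvar).sum())(mu)
```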

decode

decode(
    z: Float[Array, "batch latent"],
) -> tuple[
    Float[Array, "batch n_tnf"],
    Float[Array, "batch n_abundance"],
]

Decode latent to reconstructed features.

Parameters:

  • z (Float[Array, 'batch latent'], required): Latent representation.

Returns:

  • tuple[Float[Array, 'batch n_tnf'], Float[Array, 'batch n_abundance']]: Tuple of (tnf, abundance).

soft_cluster

soft_cluster(
    z: Float[Array, "batch latent"],
) -> Float[Array, "batch n_clusters"]

Compute soft cluster assignments.

Parameters:

  • z (Float[Array, 'batch latent'], required): Latent representations.

Returns:

  • Float[Array, 'batch n_clusters']: Soft cluster assignment probabilities.
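One common way to realize temperature-controlled soft assignments is a softmax over negative squared distances between each latent vector and a set of cluster centroids. The sketch below assumes centroids of shape (n_clusters, latent_dim); this page does not specify how `soft_cluster` defines cluster affinity, so treat it as an illustration of the technique rather than diffbio's code.

```python
import jax
import jax.numpy as jnp


def soft_cluster(z, centroids, temperature=1.0):
    """Soft assignments: softmax over -||z - c_k||^2 / temperature (illustrative sketch)."""
    # Pairwise squared distances, shape (batch, n_clusters).
    sq_dist = jnp.sum((z[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jax.nn.softmax(-sq_dist / temperature, axis=-1)


z = jnp.array([[0.0, 0.0], [4.0, 4.0]])
centroids = jnp.array([[0.0, 0.0], [4.0, 4.0], [-4.0, 4.0]])
probs = soft_cluster(z, centroids, temperature=1.0)
bins = probs.argmax(axis=-1)  # hard bins, as in the usage example
```

Raising the temperature spreads probability mass across nearby clusters; lowering it approaches a hard nearest-centroid assignment while staying differentiable.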

MetagenomicBinnerConfig

diffbio.operators.assembly.metagenomic_binning.MetagenomicBinnerConfig dataclass

MetagenomicBinnerConfig(
    temperature: float = DEFAULT_TEMPERATURE,
    learnable_temperature: bool = False,
    n_tnf_features: int = 136,
    n_abundance_features: int = 10,
    latent_dim: int = 32,
    hidden_dims: tuple[int, ...] = (512, 256),
    dropout_rate: float = 0.2,
    beta: float = 1.0,
    n_clusters: int = 100,
)

Bases: TemperatureConfig

Configuration for metagenomic binning VAE.

Attributes:

  • n_tnf_features (int): Number of tetranucleotide frequency features (default 136, the number of distinct tetranucleotides once each is merged with its reverse complement).
  • n_abundance_features (int): Number of sample abundance features.
  • latent_dim (int): Dimension of the latent space.
  • hidden_dims (tuple[int, ...]): Tuple of hidden layer dimensions for the encoder/decoder.
  • dropout_rate (float): Dropout rate for regularization.
  • beta (float): KL divergence weight (beta-VAE).
  • n_clusters (int): Number of clusters for soft binning.
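`beta` scales the KL term of the VAE objective. A minimal sketch of a beta-VAE loss for a diagonal-Gaussian posterior follows; using summed squared error for both reconstructions is an assumption, the exact loss diffbio optimizes is not documented on this page, and `beta_vae_loss` is a hypothetical name.

```python
import jax.numpy as jnp


def beta_vae_loss(tnf, abundance, recon_tnf, recon_abundance, mu, logvar, beta=1.0):
    """Reconstruction error plus beta-weighted KL(q(z|x) || N(0, I)), averaged over contigs."""
    recon = jnp.sum((tnf - recon_tnf) ** 2, axis=-1) + jnp.sum(
        (abundance - recon_abundance) ** 2, axis=-1
    )
    # Closed-form KL divergence between N(mu, exp(logvar)) and a standard normal.
    kl = -0.5 * jnp.sum(1.0 + logvar - mu**2 - jnp.exp(logvar), axis=-1)
    return jnp.mean(recon + beta * kl)


x_tnf = jnp.ones((2, 136)) / 136
x_abd = jnp.ones((2, 10))
mu = jnp.zeros((2, 32))
logvar = jnp.zeros((2, 32))
# Perfect reconstruction and a standard-normal posterior: zero loss.
loss_perfect = beta_vae_loss(x_tnf, x_abd, x_tnf, x_abd, mu, logvar, beta=1.0)
# Shifting every latent mean by 1 costs 0.5 nats per dimension, scaled by beta.
loss_shifted = beta_vae_loss(x_tnf, x_abd, x_tnf, x_abd, mu + 1.0, logvar, beta=4.0)
```

With `beta > 1` the posterior is pulled harder towards the prior, which tends to produce a smoother latent space at the cost of reconstruction accuracy.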

temperature class-attribute instance-attribute

temperature: float = DEFAULT_TEMPERATURE

learnable_temperature class-attribute instance-attribute

learnable_temperature: bool = False

create_metagenomic_binner

diffbio.operators.assembly.metagenomic_binning.create_metagenomic_binner

create_metagenomic_binner(
    n_abundance_features: int = 10,
    n_clusters: int = 100,
    latent_dim: int = 32,
    hidden_dims: tuple[int, ...] | None = None,
    seed: int = 42,
) -> DifferentiableMetagenomicBinner

Factory function to create a metagenomic binner.

Parameters:

  • n_abundance_features (int, default 10): Number of sample abundance features.
  • n_clusters (int, default 100): Number of clusters/bins.
  • latent_dim (int, default 32): Dimension of the latent space.
  • hidden_dims (tuple[int, ...] | None, default None): Hidden layer dimensions.
  • seed (int, default 42): Random seed.

Returns:

  • DifferentiableMetagenomicBinner: Configured DifferentiableMetagenomicBinner instance.

Usage Examples

Assembly Graph Navigation

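from flax import nnx
import jax
from diffbio.operators.assembly import GNNAssemblyNavigator, GNNAssemblyNavigatorConfig

config = GNNAssemblyNavigatorConfig(
    node_features=64,
    edge_features=32,
    hidden_dim=128,
    num_layers=3,
)
navigator = GNNAssemblyNavigator(config, rngs=nnx.Rngs(42))

# Placeholder inputs for illustration; feature widths must match the config above.
n_nodes, n_edges = 100, 250
k_nodes, k_edges, k_feats = jax.random.split(jax.random.PRNGKey(0), 3)
node_feats = jax.random.normal(k_nodes, (n_nodes, 64))
edges = jax.random.randint(k_edges, (2, n_edges), 0, n_nodes)
edge_feats = jax.random.normal(k_feats, (n_edges, 32))

data = {
    "node_features": node_feats,   # (n_nodes, node_features)
    "edge_index": edges,           # (2, n_edges)
    "edge_features": edge_feats,   # (n_edges, edge_features)
}
result, _, _ = navigator.apply(data, {}, None)
traversal_probs = result["traversal_probs"]  # one probability per edge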

De Bruijn Graph Construction

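import jax.numpy as jnp

# extract_kmers, embed_kmers, compute_kmer_coverage and find_overlapping_kmers
# stand for user-supplied helpers; they are not documented on this page.
# `sequences` is the input collection of reads or contigs.

# Create k-mer node features
k = 31
kmers = extract_kmers(sequences, k)
kmer_embeddings = embed_kmers(kmers)         # (n_kmers, embed_dim)
coverage = compute_kmer_coverage(kmers)      # (n_kmers,)

# Node features: embedding plus coverage; the width (embed_dim + 1)
# must equal GNNAssemblyNavigatorConfig.node_features.
node_features = jnp.concatenate([
    kmer_embeddings,
    coverage[:, None],
], axis=-1)

# Create edges between k-mers that overlap by k-1 bases
edge_index = find_overlapping_kmers(kmers, k - 1)   # (2, n_edges)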
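The helpers called above (`extract_kmers`, `embed_kmers`, `compute_kmer_coverage`, `find_overlapping_kmers`) are not defined on this page. A self-contained sketch of the same construction follows; the helper name `build_de_bruijn`, the one-hot stand-in embedding and the small k are illustrative assumptions, and reverse complements and sequencing-error filtering are ignored. Each distinct k-mer becomes a node whose last feature is its occurrence count, and an edge a -> b is added whenever the last k-1 bases of a equal the first k-1 bases of b.

```python
from collections import Counter

import jax.numpy as jnp

BASES = "ACGT"


def build_de_bruijn(sequences, k):
    """Node-centric de Bruijn graph: k-mer nodes, edges for (k-1)-overlaps (hypothetical helper)."""
    counts = Counter(seq[i:i + k] for seq in sequences for i in range(len(seq) - k + 1))
    kmers = sorted(counts)
    index = {kmer: i for i, kmer in enumerate(kmers)}

    # Stand-in embedding: flattened one-hot encoding of the k-mer (4 * k values).
    one_hot = jnp.array(
        [[[float(b == base) for base in BASES] for b in kmer] for kmer in kmers]
    ).reshape(len(kmers), 4 * k)
    coverage = jnp.array([counts[kmer] for kmer in kmers], dtype=jnp.float32)
    node_features = jnp.concatenate([one_hot, coverage[:, None]], axis=-1)

    # Edge a -> b when a's last k-1 bases equal b's first k-1 bases.
    by_prefix = {}
    for kmer in kmers:
        by_prefix.setdefault(kmer[:-1], []).append(index[kmer])
    edges = [(index[a], j) for a in kmers for j in by_prefix.get(a[1:], [])]
    edge_index = jnp.array(edges, dtype=jnp.int32).T.reshape(2, -1)  # (2, n_edges)
    return kmers, node_features, edge_index


kmers, node_features, edge_index = build_de_bruijn(["ACGTAC", "CGTACG"], k=3)
```

With this encoding the node feature width is 4k + 1, so `GNNAssemblyNavigatorConfig.node_features` would need to be set to that value (13 for k = 3).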

Metagenomic Binning

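from flax import nnx
import jax
from diffbio.operators.assembly import (
    DifferentiableMetagenomicBinner,
    MetagenomicBinnerConfig,
    create_metagenomic_binner,
)

# Using config
config = MetagenomicBinnerConfig(
    n_tnf_features=136,
    n_abundance_features=10,
    latent_dim=32,
    n_clusters=100,
)
binner = DifferentiableMetagenomicBinner(config, rngs=nnx.Rngs(42))

# Or using factory function
binner = create_metagenomic_binner(
    n_abundance_features=10,
    n_clusters=100,
    latent_dim=32,
)

# Placeholder inputs for illustration: 500 contigs, 10 samples
n_contigs = 500
k_tnf, k_abd = jax.random.split(jax.random.PRNGKey(0))
tnf_features = jax.random.uniform(k_tnf, (n_contigs, 136))
tnf_features = tnf_features / tnf_features.sum(axis=-1, keepdims=True)  # rows sum to 1
abundances = jax.random.uniform(k_abd, (n_contigs, 10))

# Apply binning
data = {
    "tnf": tnf_features,      # (n_contigs, n_tnf_features)
    "abundance": abundances,  # (n_contigs, n_abundance_features)
}
result, _, _ = binner.apply(data, {}, None)

# Get hard cluster assignments from the soft ones
bins = result["cluster_assignments"].argmax(axis=-1)  # (n_contigs,)
latent = result["latent_z"]                            # (n_contigs, latent_dim)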
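Because `cluster_assignments` are soft, the argmax can be paired with the winning probability to drop contigs whose assignment is ambiguous. The sketch below groups contig indices by bin; the helper name `group_contigs` and the 0.5 threshold are illustrative assumptions, not diffbio defaults.

```python
from collections import defaultdict

import jax.numpy as jnp


def group_contigs(cluster_assignments, min_prob=0.5):
    """Map bin id -> list of contig indices, skipping contigs below min_prob confidence."""
    bins = cluster_assignments.argmax(axis=-1)
    confidence = cluster_assignments.max(axis=-1)
    groups = defaultdict(list)
    for contig, (b, p) in enumerate(zip(bins.tolist(), confidence.tolist())):
        if p >= min_prob:
            groups[b].append(contig)
    return dict(groups)


assignments = jnp.array([
    [0.9, 0.05, 0.05],
    [0.1, 0.8, 0.1],
    [0.4, 0.3, 0.3],  # ambiguous: dropped at min_prob=0.5
    [0.7, 0.2, 0.1],
])
print(group_contigs(assignments))
```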