Assembly Operators API
Differentiable operators for genome assembly using graph neural networks and VAE-based binning.
GNNAssemblyNavigator
diffbio.operators.assembly.gnn_assembly.GNNAssemblyNavigator
```python
GNNAssemblyNavigator(
    config: GNNAssemblyNavigatorConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
```
Bases: GraphOperator
Graph Neural Network for assembly graph traversal.
This operator uses message passing with graph attention to update node embeddings and predict edge traversal probabilities for differentiable assembly.
Algorithm:
1. Project input node features to the hidden dimension.
2. Apply multiple GNN layers with graph attention.
3. Compute edge scores from source/target node embeddings.
4. Apply a sigmoid to obtain traversal probabilities.
5. Compute path confidence from the edge scores.
Inherits from GraphOperator to get:
- scatter_aggregate() for message aggregation utilities
- global_pool() for graph-level pooling
Uses temperature-controlled smoothing:
- _temperature property for the temperature-controlled sigmoid
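The five algorithm steps can be illustrated with a minimal plain-JAX sketch. It is not diffbio's implementation and rests on several assumptions: a single attention head (so `num_heads` has no counterpart), GAT-style attention logits from two learned vectors `a_src`/`a_dst`, no dropout, edge features ignored, edge scores as a scaled dot product of endpoint embeddings, and path confidence taken as the mean edge probability. The names `attention_layer`, `navigate`, and `init_params` are hypothetical. `jax.ops.segment_sum` stands in for the kind of scatter aggregation that `scatter_aggregate()` is described as providing.

```python
import jax
import jax.numpy as jnp


def attention_layer(h, edge_index, w_msg, a_src, a_dst):
    """One simplified single-head graph-attention message-passing step."""
    src, dst = edge_index
    n_nodes = h.shape[0]
    msg = h @ w_msg  # transform node states into messages
    # Unnormalized attention logit per edge from source and target messages.
    logits = jax.nn.leaky_relu(msg[src] @ a_src + msg[dst] @ a_dst)
    # Softmax over the incoming edges of each target node, using segment ops.
    max_per_node = jax.ops.segment_max(logits, dst, num_segments=n_nodes)
    weights = jnp.exp(logits - max_per_node[dst])
    denom = jax.ops.segment_sum(weights, dst, num_segments=n_nodes)
    alpha = weights / (denom[dst] + 1e-9)
    # Scatter-aggregate the weighted messages onto target nodes (residual update).
    agg = jax.ops.segment_sum(alpha[:, None] * msg[src], dst, num_segments=n_nodes)
    return jax.nn.relu(h + agg)


def navigate(node_features, edge_index, params, temperature=1.0):
    h = node_features @ params["w_in"]  # step 1: project to hidden_dim
    for layer in params["layers"]:  # step 2: stacked attention layers
        h = attention_layer(h, edge_index, *layer)
    src, dst = edge_index
    # Step 3: edge score from source/target embeddings (scaled dot product).
    scores = jnp.sum(h[src] * h[dst], axis=-1) / jnp.sqrt(h.shape[-1])
    probs = jax.nn.sigmoid(scores / temperature)  # step 4: temperature-controlled sigmoid
    confidence = jnp.mean(probs)  # step 5: path confidence (assumed: mean probability)
    return probs, confidence


def init_params(key, in_dim, hidden_dim, num_layers):
    keys = jax.random.split(key, 1 + 3 * num_layers)
    scale = 1.0 / jnp.sqrt(hidden_dim)
    params = {
        "w_in": jax.random.normal(keys[0], (in_dim, hidden_dim)) / jnp.sqrt(in_dim),
        "layers": [],
    }
    for i in range(num_layers):
        k_w, k_s, k_d = keys[1 + 3 * i : 4 + 3 * i]
        params["layers"].append((
            jax.random.normal(k_w, (hidden_dim, hidden_dim)) * scale,
            jax.random.normal(k_s, (hidden_dim,)) * scale,
            jax.random.normal(k_d, (hidden_dim,)) * scale,
        ))
    return params


x = jax.random.normal(jax.random.PRNGKey(0), (5, 8))
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 4]])  # a simple path 0 -> 1 -> 2 -> 3 -> 4
params = init_params(jax.random.PRNGKey(1), in_dim=8, hidden_dim=16, num_layers=2)
probs, confidence = navigate(x, edge_index, params, temperature=1.0)
```

Raising `temperature` flattens every edge probability toward 0.5, while lowering it pushes probabilities toward 0 or 1; this is the knob the temperature-controlled sigmoid exposes.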
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | GNNAssemblyNavigatorConfig | GNNAssemblyNavigatorConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | GNNAssemblyNavigatorConfig | Navigator configuration. | required |
| rngs | Rngs \| None | Random number generators for initialization. | None |
| name | str \| None | Optional operator name. | None |
apply
```python
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
```
Apply GNN assembly navigation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing "node_features": node features (n_nodes, node_features); "edge_index": edge indices (2, n_edges); "edge_features": edge features (n_edges, edge_features). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata); transformed_data contains: |
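The `data` dictionary that `apply()` expects has to be internally consistent: the edge arrays must agree on `n_edges`, and the feature widths must match the config. The helpers below are a hypothetical sketch (not part of diffbio) that builds random inputs with the documented shapes and reports mismatches before the operator is called. The defaults 64 and 8 mirror the `GNNAssemblyNavigatorConfig` defaults for `node_features` and `edge_features`; requiring `edge_index` to hold integer node indices in `[0, n_nodes)` is an assumption.

```python
import jax
import jax.numpy as jnp


def make_graph_inputs(key, n_nodes, n_edges, node_dim=64, edge_dim=8):
    """Random input dict with the shapes documented for apply()."""
    k_nodes, k_edges, k_feats = jax.random.split(key, 3)
    return {
        "node_features": jax.random.normal(k_nodes, (n_nodes, node_dim)),
        "edge_index": jax.random.randint(k_edges, (2, n_edges), 0, n_nodes),
        "edge_features": jax.random.normal(k_feats, (n_edges, edge_dim)),
    }


def graph_input_problems(data, node_dim, edge_dim):
    """Return human-readable shape problems; an empty list means the dict looks valid."""
    problems = []
    node_features = data["node_features"]
    edge_index = data["edge_index"]
    edge_features = data["edge_features"]
    if node_features.ndim != 2 or node_features.shape[1] != node_dim:
        problems.append(f"node_features should be (n_nodes, {node_dim}), got {node_features.shape}")
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        problems.append(f"edge_index should be (2, n_edges), got {edge_index.shape}")
        return problems
    n_edges = edge_index.shape[1]
    if edge_features.shape != (n_edges, edge_dim):
        problems.append(f"edge_features should be ({n_edges}, {edge_dim}), got {edge_features.shape}")
    n_nodes = node_features.shape[0]
    if n_edges and (int(edge_index.min()) < 0 or int(edge_index.max()) >= n_nodes):
        problems.append("edge_index refers to a node outside [0, n_nodes)")
    return problems


data = make_graph_inputs(jax.random.PRNGKey(0), n_nodes=6, n_edges=9)
print(graph_input_problems(data, node_dim=64, edge_dim=8))  # an empty list for consistent inputs
```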
GNNAssemblyNavigatorConfig
diffbio.operators.assembly.gnn_assembly.GNNAssemblyNavigatorConfig
dataclass
```python
GNNAssemblyNavigatorConfig(
    temperature: float = 1.0,
    learnable_temperature: bool = False,
    node_features: int = 64,
    hidden_dim: int = 128,
    num_layers: int = 3,
    num_heads: int = 4,
    edge_features: int = 8,
    dropout_rate: float = 0.1,
)
```
Bases: TemperatureConfig
Configuration for GNNAssemblyNavigator.
Attributes:
| Name | Type | Description |
|---|---|---|
| node_features | int | Dimension of input node features. |
| hidden_dim | int | Hidden dimension for GNN layers. |
| num_layers | int | Number of GNN layers. |
| num_heads | int | Number of attention heads. |
| edge_features | int | Dimension of edge features. |
| dropout_rate | float | Dropout rate for regularization. |
| temperature | float | Temperature for softmax operations. |
DifferentiableMetagenomicBinner
diffbio.operators.assembly.metagenomic_binning.DifferentiableMetagenomicBinner
```python
DifferentiableMetagenomicBinner(
    config: MetagenomicBinnerConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)
```
Bases: TemperatureOperator, EncoderDecoderOperator
VAMB-style differentiable metagenomic binning.
This operator implements a Variational Autoencoder for metagenomic binning, encoding tetranucleotide frequencies (TNF) and abundance profiles into a shared latent space where contigs from the same genome cluster together.
The approach is fully differentiable, enabling:
- End-to-end optimization with downstream tasks
- Soft cluster assignments via temperature-controlled softmax
- Integration with neural abundance estimation
Input data structure:
- tnf: Float[Array, "n_contigs n_tnf"] - Tetranucleotide frequencies
- abundance: Float[Array, "n_contigs n_samples"] - Sample abundances

Output data structure (adds):
- latent_z: Float[Array, "n_contigs latent_dim"] - Latent representations
- latent_mu: Float[Array, "n_contigs latent_dim"] - Latent means
- latent_logvar: Float[Array, "n_contigs latent_dim"] - Latent log variances
- cluster_assignments: Float[Array, "n_contigs n_clusters"] - Soft bins
- reconstructed_tnf: Float[Array, "n_contigs n_tnf"] - Reconstructed TNF
- reconstructed_abundance: Float[Array, "n_contigs n_samples"] - Reconstructed abundance
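The `tnf` input is usually derived from contig sequences. A common convention, and the one that yields the 136-feature default in `MetagenomicBinnerConfig`, merges each 4-mer with its reverse complement: 256 tetranucleotides collapse to 136 canonical forms because 16 of them are their own reverse complement. The sketch below computes such a profile for one contig; the function and constant names are hypothetical, and diffbio's own preprocessing (pseudo-counts, ordering of the 136 columns, further normalization) is not documented here, so treat this as an assumption rather than the library's pipeline.

```python
from itertools import product

import jax.numpy as jnp

_COMPLEMENT = str.maketrans("ACGT", "TGCA")


def canonical(kmer):
    """Lexicographically smaller of a k-mer and its reverse complement."""
    return min(kmer, kmer.translate(_COMPLEMENT)[::-1])


# 256 tetranucleotides collapse to 136 canonical forms.
CANONICAL_4MERS = sorted({canonical("".join(p)) for p in product("ACGT", repeat=4)})
_INDEX = {kmer: i for i, kmer in enumerate(CANONICAL_4MERS)}


def tetranucleotide_frequencies(seq):
    """Normalized canonical 4-mer counts for one contig, shape (136,)."""
    counts = [0] * len(CANONICAL_4MERS)
    seq = seq.upper()
    for i in range(len(seq) - 3):
        kmer = seq[i:i + 4]
        if set(kmer) <= set("ACGT"):  # skip windows containing N or other symbols
            counts[_INDEX[canonical(kmer)]] += 1
    counts = jnp.asarray(counts, dtype=jnp.float32)
    return counts / jnp.maximum(counts.sum(), 1.0)  # all zeros if no valid window


contigs = ["ACGTACGTNNGGCC", "AAACGTTTAGC"]
tnf = jnp.stack([tetranucleotide_frequencies(c) for c in contigs])  # (n_contigs, 136)
```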
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MetagenomicBinnerConfig | Binner configuration. | required |
| rngs | Rngs | Random number generators. | required |
| name | str \| None | Optional name for the operator. | None |
apply
```python
apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]
```
Apply metagenomic binning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Array] | Input data containing tnf: Float[Array, "n_contigs n_tnf"] and abundance: Float[Array, "n_contigs n_samples"]. | required |
| state | dict[str, Any] | Element state (passed through). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through). | required |
| random_params | Any | Random parameters. | None |
| stats | dict[str, Any] \| None | Optional statistics dict. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Array], dict[str, Any], dict[str, Any] \| None] | Tuple of (output_data, state, metadata). |
encode
```python
encode(
    x: Float[Array, "batch input_dim"],
) -> tuple[
    Float[Array, "batch latent"],
    Float[Array, "batch latent"],
]
```
Encode input to latent distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, 'batch input_dim'] | Concatenated TNF and abundance features. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'batch latent'], Float[Array, 'batch latent']] | Tuple of (mu, logvar). |
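`encode` returns a mean and a log-variance rather than a single latent vector, and the output structure lists `latent_z` alongside `latent_mu` and `latent_logvar`. In a standard VAE, `z` is drawn with the reparameterization trick so that sampling stays differentiable with respect to `mu` and `logvar`. The sketch below shows that construction; it assumes diffbio follows the standard form, which this page does not confirm, and `reparameterize` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) as a differentiable function of mu and logvar."""
    eps = jax.random.normal(key, mu.shape)  # noise independent of the parameters
    return mu + jnp.exp(0.5 * logvar) * eps  # exp(0.5 * logvar) is the standard deviation


key = jax.random.PRNGKey(0)
mu = jnp.zeros((4, 32))
logvar = jnp.zeros((4, 32))
z = reparameterize(key, mu, logvar)
# Gradients flow through the sample back to mu (here d sum(z) / d mu is all ones).
grad_mu = jax.grad(lambda m: reparameterize(key, m, logvar).sum())(mu)
```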
decode
```python
decode(
    z: Float[Array, "batch latent"],
) -> tuple[
    Float[Array, "batch n_tnf"],
    Float[Array, "batch n_abundance"],
]
```
Decode latent to reconstructed features.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| z | Float[Array, 'batch latent'] | Latent representation. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'batch n_tnf'], Float[Array, 'batch n_abundance']] | Tuple of (tnf, abundance). |
soft_cluster
Compute soft cluster assignments.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| z | Float[Array, 'batch latent'] | Latent representations. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'batch n_clusters'] | Soft cluster assignment probabilities. |
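This page does not say how `soft_cluster` turns latent vectors into probabilities. One common construction, shown here purely as an assumption, keeps `n_clusters` centroids in latent space and applies a temperature-scaled softmax to negative squared distances, so each row sums to 1 and a lower temperature sharpens the assignment. `soft_assign` and the fixed `centroids` array are hypothetical; in a trained model the centroids would typically be learned parameters.

```python
import jax
import jax.numpy as jnp


def soft_assign(z, centroids, temperature=1.0):
    """Soft assignment of each latent vector to each centroid (rows sum to 1)."""
    # Squared Euclidean distance, shape (batch, n_clusters).
    d2 = jnp.sum((z[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    # Nearer centroids get larger logits; a lower temperature sharpens the assignment.
    return jax.nn.softmax(-d2 / temperature, axis=-1)


centroids = jnp.array([[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]])
z = jnp.array([[0.0, 0.0], [3.1, 0.0]])
sharp = soft_assign(z, centroids, temperature=0.1)
smooth = soft_assign(z, centroids, temperature=10.0)
```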
MetagenomicBinnerConfig
diffbio.operators.assembly.metagenomic_binning.MetagenomicBinnerConfig
dataclass
```python
MetagenomicBinnerConfig(
    temperature: float = DEFAULT_TEMPERATURE,
    learnable_temperature: bool = False,
    n_tnf_features: int = 136,
    n_abundance_features: int = 10,
    latent_dim: int = 32,
    hidden_dims: tuple[int, ...] = (512, 256),
    dropout_rate: float = 0.2,
    beta: float = 1.0,
    n_clusters: int = 100,
)
```
Bases: TemperatureConfig
Configuration for metagenomic binning VAE.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_tnf_features | int | Number of tetranucleotide frequency features (default 136, the number of distinct tetranucleotides once each is merged with its reverse complement). |
| n_abundance_features | int | Number of sample abundance features. |
| latent_dim | int | Dimension of the latent space. |
| hidden_dims | tuple[int, ...] | Tuple of hidden layer dimensions for the encoder/decoder. |
| dropout_rate | float | Dropout rate for regularization. |
| beta | float | KL divergence weight (beta-VAE). |
| n_clusters | int | Number of clusters for soft binning. |
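The `beta` attribute weights the KL term of a beta-VAE objective. The sketch below shows how such a weight typically enters the loss; the reconstruction terms (per-contig summed squared error for both TNF and abundance) are an assumption, since this page does not document which reconstruction losses diffbio uses. The KL term is the closed form for a diagonal Gaussian against a standard normal prior. Both function names are hypothetical.

```python
import jax.numpy as jnp


def kl_divergence(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over the batch."""
    return jnp.mean(-0.5 * jnp.sum(1.0 + logvar - mu**2 - jnp.exp(logvar), axis=-1))


def beta_vae_loss(tnf, abundance, recon_tnf, recon_abundance, mu, logvar, beta=1.0):
    # Assumed reconstruction terms: per-contig summed squared error, averaged over contigs.
    recon = jnp.mean(jnp.sum((tnf - recon_tnf) ** 2, axis=-1)) + jnp.mean(
        jnp.sum((abundance - recon_abundance) ** 2, axis=-1)
    )
    # beta > 1 pushes the posterior toward the prior; beta < 1 favors reconstruction.
    return recon + beta * kl_divergence(mu, logvar)


tnf = jnp.ones((3, 136)) / 136
abundance = jnp.ones((3, 10))
mu, logvar = jnp.ones((3, 32)), jnp.zeros((3, 32))
loss_beta1 = beta_vae_loss(tnf, abundance, tnf, abundance, mu, logvar, beta=1.0)
loss_beta3 = beta_vae_loss(tnf, abundance, tnf, abundance, mu, logvar, beta=3.0)
```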
create_metagenomic_binner
diffbio.operators.assembly.metagenomic_binning.create_metagenomic_binner
```python
create_metagenomic_binner(
    n_abundance_features: int = 10,
    n_clusters: int = 100,
    latent_dim: int = 32,
    hidden_dims: tuple[int, ...] | None = None,
    seed: int = 42,
) -> DifferentiableMetagenomicBinner
```
Factory function to create a metagenomic binner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_abundance_features | int | Number of sample abundance features. | 10 |
| n_clusters | int | Number of clusters/bins. | 100 |
| latent_dim | int | Dimension of latent space. | 32 |
| hidden_dims | tuple[int, ...] \| None | Hidden layer dimensions. | None |
| seed | int | Random seed. | 42 |
|
Returns:
| Type | Description |
|---|---|
| DifferentiableMetagenomicBinner | Configured DifferentiableMetagenomicBinner instance. |
Usage Examples
Assembly Graph Navigation
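```python
import jax
from flax import nnx
from diffbio.operators.assembly import GNNAssemblyNavigator, GNNAssemblyNavigatorConfig

config = GNNAssemblyNavigatorConfig(
    node_features=64,
    edge_features=32,
    hidden_dim=128,
    num_layers=3,
)
navigator = GNNAssemblyNavigator(config, rngs=nnx.Rngs(42))

# Placeholder graph with random features; replace with a real assembly graph.
n_nodes, n_edges = 50, 120
k_nodes, k_edges, k_feats = jax.random.split(jax.random.key(0), 3)
node_feats = jax.random.normal(k_nodes, (n_nodes, config.node_features))
edges = jax.random.randint(k_edges, (2, n_edges), 0, n_nodes)
edge_feats = jax.random.normal(k_feats, (n_edges, config.edge_features))

data = {
    "node_features": node_feats,  # (n_nodes, node_features)
    "edge_index": edges,  # (2, n_edges), integer node indices
    "edge_features": edge_feats,  # (n_edges, edge_features)
}
result, _, _ = navigator.apply(data, {}, None)
next_node_probs = result["next_node_probs"]
```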
De Bruijn Graph Construction
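```python
import jax.numpy as jnp

# NOTE: extract_kmers, embed_kmers, compute_kmer_coverage and find_overlapping_kmers
# are placeholders that are not documented on this page; supply your own implementations.
# `sequences` is your collection of reads.

# Create k-mer node features
k = 31
kmers = extract_kmers(sequences, k)
kmer_embeddings = embed_kmers(kmers)  # (n_kmers, embed_dim)
coverage = compute_kmer_coverage(kmers)  # (n_kmers,)
node_features = jnp.concatenate([
    kmer_embeddings,
    coverage[:, None],  # append coverage as one extra feature column
], axis=-1)

# Create edges between k-mers that overlap by k-1 bases
edge_index = find_overlapping_kmers(kmers, k - 1)  # (2, n_edges)
```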
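For readers who want a runnable starting point, the sketch below gives minimal pure-Python/JAX stand-ins for the placeholder helpers. The names `count_kmers`, `one_hot_kmers`, and `kmer_overlap_edges` are hypothetical and deliberately differ from the placeholders; they are not diffbio functions. Assumptions: only A/C/G/T windows are kept, k-mers are not merged with their reverse complements, there is no error correction, and the embedding is a flattened per-position one-hot encoding. The example uses k=3 so the graph is easy to check by hand; real assemblies would use a larger k such as the 31 shown above.

```python
from collections import Counter

import jax
import jax.numpy as jnp

_BASE_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}


def count_kmers(sequences, k):
    """Count k-mers over all reads, skipping windows with non-ACGT symbols."""
    counts = Counter(
        seq[i:i + k]
        for seq in sequences
        for i in range(len(seq) - k + 1)
        if set(seq[i:i + k]) <= set(_BASE_INDEX)
    )
    kmers = sorted(counts)
    coverage = jnp.asarray([counts[km] for km in kmers], dtype=jnp.float32)
    return kmers, coverage


def one_hot_kmers(kmers):
    """Toy embedding: per-position one-hot of each base, flattened to 4 * k values."""
    idx = jnp.asarray([[_BASE_INDEX[b] for b in km] for km in kmers])
    return jax.nn.one_hot(idx, 4).reshape(len(kmers), -1)


def kmer_overlap_edges(kmers):
    """Directed edge u -> v when the last k-1 bases of u equal the first k-1 bases of v."""
    by_prefix = {}
    for j, km in enumerate(kmers):
        by_prefix.setdefault(km[:-1], []).append(j)
    src, dst = [], []
    for i, km in enumerate(kmers):
        for j in by_prefix.get(km[1:], []):
            src.append(i)
            dst.append(j)
    return jnp.asarray([src, dst], dtype=jnp.int32)  # (2, n_edges)


kmers, coverage = count_kmers(["ACGTAC"], k=3)
node_features = jnp.concatenate([one_hot_kmers(kmers), coverage[:, None]], axis=-1)
edge_index = kmer_overlap_edges(kmers)
```

For the read `ACGTAC` this yields the four 3-mers ACG, CGT, GTA, TAC and the cycle ACG -> CGT -> GTA -> TAC -> ACG, since the suffix `AC` of TAC is also the prefix of ACG.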
Metagenomic Binning
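```python
import jax
from flax import nnx
from diffbio.operators.assembly import (
    DifferentiableMetagenomicBinner,
    MetagenomicBinnerConfig,
    create_metagenomic_binner,
)

# Using config
config = MetagenomicBinnerConfig(
    n_tnf_features=136,
    n_abundance_features=10,
    latent_dim=32,
    n_clusters=100,
)
binner = DifferentiableMetagenomicBinner(config, rngs=nnx.Rngs(42))

# Or using factory function
binner = create_metagenomic_binner(
    n_abundance_features=10,
    n_clusters=100,
    latent_dim=32,
)

# Placeholder inputs; in practice use real TNF profiles and per-sample coverages.
# n_samples is assumed to match n_abundance_features (10 here).
n_contigs = 500
key_tnf, key_abd = jax.random.split(jax.random.key(0))
tnf_features = jax.nn.softmax(jax.random.normal(key_tnf, (n_contigs, 136)), axis=-1)
abundances = jax.random.uniform(key_abd, (n_contigs, 10))

# Apply binning
data = {
    "tnf": tnf_features,  # (n_contigs, 136)
    "abundance": abundances,  # (n_contigs, n_samples)
}
result, _, _ = binner.apply(data, {}, None)

# Get cluster assignments
bins = result["cluster_assignments"].argmax(axis=-1)  # hard bin index per contig
latent = result["latent_z"]
```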
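Taking `argmax` over `cluster_assignments` forces every contig into some bin, even when its probability mass is spread across several clusters. A common post-processing step, shown here as an assumption rather than something diffbio performs, is to leave low-confidence contigs unbinned and group the rest by label. The helper name `bins_from_assignments` and the `min_prob` threshold are hypothetical; a suitable threshold depends on `n_clusters` and the binner's temperature.

```python
import jax.numpy as jnp


def bins_from_assignments(assignments, min_prob=0.5):
    """Hard labels per contig (-1 = unbinned) plus a dict mapping bin -> contig indices."""
    labels = jnp.argmax(assignments, axis=-1)
    confident = jnp.max(assignments, axis=-1) >= min_prob
    labels = jnp.where(confident, labels, -1)
    bins = {}
    for contig, label in enumerate(labels.tolist()):
        if label >= 0:
            bins.setdefault(label, []).append(contig)
    return labels, bins


# Toy soft assignments for three contigs over two clusters.
assignments = jnp.array([[0.9, 0.1], [0.4, 0.6], [0.5, 0.5]])
labels, bins = bins_from_assignments(assignments, min_prob=0.55)
```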