
Operators Overview¤

DiffBio provides a collection of differentiable operators for bioinformatics analysis. Each operator inherits from Datarax's OperatorModule for consistent interfaces and composability.

Foundation: Soft Operations¤

All DiffBio operators are built on a shared layer of differentiable primitives in diffbio.core.soft_ops. This module provides 79 smooth relaxations of discrete operations (sorting, comparisons, logical gates, selection) with 5 smoothness modes and multiple algorithmic backends.

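Operators do not call JAX's hard jnp.max, jnp.argmax, or jnp.sort directly. The hard versions give gradients that are zero almost everywhere (jnp.argmax) or that flow only through the selected elements (jnp.max, jnp.sort). Instead, operators use soft_ops equivalents that produce smooth, informative gradients: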

from diffbio.core import soft_ops

# Soft quality threshold (returns probability, not boolean)
is_good = soft_ops.greater(quality, 20.0, softness=1.0)

# Soft sorting for gene ranking
ranked = soft_ops.sort(expression, axis=0, softness=0.1)

# Soft top-k variant selection
values, indices = soft_ops.top_k(scores, k=10, softness=0.1)

See the Soft Operations concept guide for a detailed walkthrough, or the API reference for the full function list.
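
To make the idea of a soft comparison concrete, here is a minimal sketch in plain JAX. It assumes a sigmoid relaxation; `soft_greater_sketch` is a hypothetical stand-in and not the actual implementation of soft_ops.greater, which offers several smoothness modes. The sketch compares the gradient of a hard threshold with that of its soft counterpart.

```python
import jax
import jax.numpy as jnp


def hard_greater(x, threshold):
    # Step function: its derivative with respect to the threshold is zero.
    return jnp.where(x > threshold, 1.0, 0.0)


def soft_greater_sketch(x, threshold, softness=1.0):
    # Hypothetical stand-in for soft_ops.greater: sigmoid of the scaled margin.
    return jax.nn.sigmoid((x - threshold) / softness)


quality = jnp.array([10.0, 20.0, 30.0])

# Probabilities instead of booleans; exactly 0.5 at the threshold.
probs = soft_greater_sketch(quality, 20.0, softness=5.0)

# Gradient of the (soft) number of passing reads with respect to the threshold.
hard_grad = jax.grad(lambda t: hard_greater(quality, t).sum())(20.0)
soft_grad = jax.grad(lambda t: soft_greater_sketch(quality, t, 5.0).sum())(20.0)
```

The hard version yields a zero gradient, so an optimizer gets no signal about which way to move the threshold. The soft version yields a negative gradient: raising the threshold lowers the expected number of passing reads.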


Available Operators¤

Alignment
Smith-Waterman
Pileup
Read Aggregation
Filter
Quality Control

Core Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| DifferentiableQualityFilter | Sigmoid-based soft quality filtering | Implemented |
| DifferentiablePileup | Soft pileup generation | Implemented |
| SmoothSmithWaterman | Differentiable local sequence alignment | Implemented |
| VariantClassifier | Neural variant classifier | Implemented |

Alignment Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| SoftProgressiveMSA | Differentiable multiple sequence alignment with guide tree | Implemented |
| ProfileHMMSearch | Profile Hidden Markov Model for sequence homology | Implemented |

Epigenomics Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| DifferentiablePeakCaller | CNN-based peak calling for ChIP-seq/ATAC-seq | Implemented |
| ChromatinStateAnnotator | HMM-based chromatin state classification | Implemented |

RNA-seq Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| SplicingPSI | Differentiable PSI calculation for alternative splicing | Implemented |
| DifferentiableMotifDiscovery | Learnable PWM-based motif discovery | Implemented |

Single-Cell Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| SoftKMeansClustering | Differentiable soft k-means with learnable centroids | Implemented |
| DifferentiableHarmony | Harmony-style batch correction | Implemented |
| DifferentiableVelocity | RNA velocity via neural ODEs | Implemented |
| DifferentiableAmbientRemoval | VAE-based ambient RNA decontamination | Implemented |

Preprocessing Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| SoftAdapterRemoval | Differentiable adapter trimming with soft alignment | Implemented |
| DifferentiableDuplicateWeighting | Probabilistic duplicate weighting | Implemented |
| SoftErrorCorrection | Neural network-based error correction | Implemented |

Normalization Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| VAENormalizer | scVI-style VAE for count normalization | Implemented |
| DifferentiableUMAP | Differentiable UMAP dimensionality reduction | Implemented |
| SequenceEmbedding | Learned sequence embeddings | Implemented |

Statistical Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| DifferentiableHMM | Forward algorithm with logsumexp stability | Implemented |
| DifferentiableNBGLM | Negative binomial GLM for differential expression | Implemented |
| DifferentiableEMQuantifier | Unrolled EM for transcript quantification | Implemented |

Assembly & Mapping Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| GNNAssemblyNavigator | GNN for assembly graph traversal | Implemented |
| NeuralReadMapper | Cross-attention based read mapping | Implemented |
| DifferentiableMetagenomicBinner | VAMB-style VAE for metagenomic binning | Implemented |

Multi-omics Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| SpatialDeconvolution | Cell type deconvolution for spatial transcriptomics | Implemented |
| HiCContactAnalysis | Chromatin contact analysis for Hi-C data | Implemented |
| DifferentiableSpatialGeneDetector | SpatialDE-style spatial gene detection | Implemented |

Variant Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| CNNVariantClassifier | CNN-based variant classification | Implemented |
| DifferentiableCNVSegmentation | Copy number variation segmentation | Implemented |
| SoftVariantQualityFilter | Base quality score recalibration | Implemented |
| DeepVariantStylePileup | Multi-channel pileup image generation for DeepVariant-style CNNs | Implemented |

Population Genetics Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| DifferentiableAncestryEstimator | Neural ADMIXTURE-style ancestry estimation | Implemented |

CRISPR Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| DifferentiableCRISPRScorer | DeepCRISPR-style guide RNA efficiency prediction | Implemented |

Metabolomics Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| DifferentiableSpectralSimilarity | MS2DeepScore-style Siamese network for MS/MS similarity | Implemented |

Protein Structure Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| DifferentiableSecondaryStructure | PyDSSP-style DSSP with continuous H-bond matrix | Implemented |

Foundation Model Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| TransformerSequenceEncoder | DNABERT/RNA-FM-style transformer for sequence embedding | Implemented |

RNA Structure Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| DifferentiableRNAFold | McCaskill-style partition function for base pair probabilities | Implemented |

Molecular Dynamics Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| ForceFieldOperator | Differentiable force field (LJ, Morse, Soft Sphere) using JAX-MD | Implemented |
| MDIntegratorOperator | Time integration for MD (velocity Verlet, Langevin) using JAX-MD | Implemented |

Drug Discovery Operators¤

| Operator | Description | Status |
| --- | --- | --- |
| MolecularPropertyPredictor | ChemProp-style D-MPNN for molecular property prediction | Implemented |
| ADMETPredictor | Multi-task ADMET prediction (22 TDC endpoints) | Implemented |
| DifferentiableMolecularFingerprint | Neural graph fingerprints as alternative to ECFP/Morgan | Implemented |
| CircularFingerprintOperator | Differentiable ECFP/Morgan circular fingerprints | Implemented |
| MACCSKeysOperator | Differentiable MACCS 166 structural keys fingerprint | Implemented |
| AttentiveFP | Attention-based graph fingerprint with GRU (Xiong et al. 2019) | Implemented |
| MolecularSimilarityOperator | Differentiable Tanimoto/cosine/Dice similarity | Implemented |

Operator Interface¤

All DiffBio operators implement the Datarax OperatorModule interface:

class OperatorModule:
    def apply(
        self,
        data: PyTree,
        state: PyTree,
        metadata: dict | None,
        random_params: Any = None,
        stats: dict | None = None,
    ) -> tuple[PyTree, PyTree, dict | None]:
        """Transform data through the operator.

        Args:
            data: Input data as a PyTree (typically dict)
            state: Per-element state (passed through or modified)
            metadata: Optional metadata
            random_params: Random parameters for stochastic operators
            stats: Optional statistics dictionary

        Returns:
            Tuple of (transformed_data, updated_state, updated_metadata)
        """
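
As an illustration of the interface above, the following sketch defines a toy operator with the same apply signature. `ScaleOperatorSketch` is hypothetical and deliberately does not inherit from Datarax's OperatorModule, so it only demonstrates the calling convention, not the real base-class machinery.

```python
import jax.numpy as jnp


class ScaleOperatorSketch:
    """Hypothetical toy operator; it does not inherit from Datarax."""

    def __init__(self, scale: float = 2.0):
        self.scale = scale

    def apply(self, data, state, metadata, random_params=None, stats=None):
        # Keep the input keys and add the transformed array under a new key.
        transformed = {**data, "scaled": data["values"] * self.scale}
        # State and metadata are passed through unchanged.
        return transformed, state, metadata


op = ScaleOperatorSketch(scale=2.0)
out, state, meta = op.apply({"values": jnp.array([1.0, 2.0])}, {}, None)
```

The call returns the same three-element tuple described in the docstring: the transformed data, the per-element state, and the metadata.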

Configuration Pattern¤

Each operator has a corresponding configuration dataclass:

from dataclasses import dataclass
from datarax.core.config import OperatorConfig

@dataclass(frozen=True)
class MyOperatorConfig(OperatorConfig):
    """Configuration for MyOperator."""
    temperature: float = 1.0
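
Because the configuration is a frozen dataclass, it cannot be mutated after creation; a modified copy has to be derived instead. The sketch below shows this with the standard library only. `OperatorConfigSketch` is a hypothetical stand-in for datarax.core.config.OperatorConfig, whose actual fields are not shown here.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class OperatorConfigSketch:
    """Hypothetical stand-in for datarax.core.config.OperatorConfig."""


@dataclass(frozen=True)
class MyOperatorConfig(OperatorConfigSketch):
    """Configuration for MyOperator."""
    temperature: float = 1.0


config = MyOperatorConfig(temperature=5.0)

# Frozen dataclasses reject in-place mutation.
try:
    config.temperature = 0.1
    mutated = True
except FrozenInstanceError:
    mutated = False

# Derive a modified copy instead; the original stays untouched.
annealed = replace(config, temperature=0.1)
```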

Example Usage¤

from diffbio.operators import DifferentiableQualityFilter, QualityFilterConfig

# 1. Create configuration
config = QualityFilterConfig(initial_threshold=20.0)

# 2. Instantiate operator
operator = DifferentiableQualityFilter(config)

# 3. Prepare data (sequence_tensor and quality_tensor are arrays defined elsewhere)
data = {
    "sequence": sequence_tensor,
    "quality_scores": quality_tensor,
}

# 4. Apply operator
result_data, state, metadata = operator.apply(data, {}, None)

Composing Operators¤

Operators can be composed into pipelines either with Datarax's composition utilities or by calling them in sequence by hand.

Sequential Composition¤

from datarax.operators import CompositeOperatorModule

# Chain operators sequentially
pipeline = CompositeOperatorModule([
    quality_filter,
    aligner,
    pileup_generator,
])

# Apply entire pipeline
result, state, meta = pipeline.apply(data, {}, None)

Manual Composition¤

def my_pipeline(data):
    # Step 1: Quality filtering
    data, state, meta = quality_filter.apply(data, {}, None)

    # Step 2: Alignment
    data, state, meta = aligner.apply(data, state, meta)

    # Step 3: Pileup
    data, state, meta = pileup_op.apply(data, state, meta)

    return data
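
The manual pattern generalizes to any list of operators. The sketch below is one way to write that loop; `chain_operators` and `AddKey` are hypothetical helpers for illustration and are not part of Datarax or DiffBio. It threads data, state, and metadata through each operator in turn, which is the behavior the manual pipeline above spells out step by step.

```python
def chain_operators(operators, data, state=None, metadata=None):
    """Thread (data, state, metadata) through each operator in order."""
    state = {} if state is None else state
    for op in operators:
        data, state, metadata = op.apply(data, state, metadata)
    return data, state, metadata


class AddKey:
    """Hypothetical toy operator: adds one key and counts calls in state."""

    def __init__(self, key, value):
        self.key, self.value = key, value

    def apply(self, data, state, metadata, random_params=None, stats=None):
        new_state = {**state, "calls": state.get("calls", 0) + 1}
        return {**data, self.key: self.value}, new_state, metadata


data, state, meta = chain_operators(
    [AddKey("a", 1), AddKey("b", 2)], {"input": 0}
)
```

Unlike the manual version, which returns only the data, this helper also returns the final state and metadata so that callers can inspect them.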

Learnable Parameters¤

DiffBio operators use Flax NNX for parameter management.

Accessing Parameters¤

from flax import nnx

# Get all parameters
params = nnx.state(operator, nnx.Param)
print(params)

# Access specific parameter
print(operator.threshold[...])  # Array value

Updating Parameters¤

# Manual update
operator.threshold[...] = new_value

# Gradient-based update
import optax

optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

def update_step(params, grads, opt_state):
    # Turn gradients into Adam updates and apply them to the parameters
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# After a step, write the new values back with nnx.update(operator, params)

JAX Transformations¤

All operators are compatible with JAX transformations.

JIT Compilation¤

import jax

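@jax.jit
def apply_operator(data):
    # The operator is captured by closure, so its parameters are baked in
    # as constants when the function is traced.
    result, _, _ = operator.apply(data, {}, None)
    return result

# The first call traces and compiles; later calls with the same input shapes reuse the compiled code
result = apply_operator(data)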
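
The compile-once behavior can be observed directly. This sketch uses a plain function instead of a DiffBio operator (an assumption made to keep it self-contained) and records each trace through a Python side effect, which only runs while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_log = []


@jax.jit
def soft_filter(quality):
    # Python side effects run only during tracing, not on cached calls.
    trace_log.append(quality.shape)
    return jax.nn.sigmoid(quality - 20.0)


soft_filter(jnp.ones(4))   # traces and compiles
soft_filter(jnp.zeros(4))  # same shape and dtype: reuses the compiled code
traces_same_shape = len(trace_log)

soft_filter(jnp.ones(8))   # new shape: traces again
traces_new_shape = len(trace_log)
```

Keeping input shapes stable across calls therefore matters for performance: every new shape or dtype triggers another trace and compilation.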

Vectorization¤

# Process batch of inputs
def single_apply(single_data):
    result, _, _ = operator.apply(single_data, {}, None)
    return result

batch_apply = jax.vmap(single_apply)
batch_results = batch_apply(batch_data)
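
With dictionary inputs, jax.vmap maps over the leading axis of every array in the dictionary. The sketch below illustrates this with a hypothetical per-read scoring function in place of a real operator.

```python
import jax
import jax.numpy as jnp


def score_read(read):
    # Operates on a single read: a dict holding a 1-D quality vector.
    weights = jax.nn.sigmoid(read["quality_scores"] - 20.0)
    return {"score": weights.mean()}


# A batch of 3 reads with 5 bases each.
batch = {"quality_scores": jnp.full((3, 5), 30.0)}

# vmap maps over the leading axis of every leaf in the input pytree.
batch_scores = jax.vmap(score_read)(batch)
```

The output keeps the dictionary structure, and each leaf gains a leading batch axis, so batch_scores["score"] has one entry per read.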

Gradient Computation¤

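from flax import nnx

def loss_fn(operator, data):
    result, _, _ = operator.apply(data, {}, None)
    return result['score'].mean()

# Compute gradients w.r.t. operator parameters (its nnx.Param variables).
# jax.grad over a function of data alone would differentiate w.r.t. the input data instead.
grad_fn = nnx.grad(loss_fn)
grads = grad_fn(operator, data)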

Best Practices¤

1. Use Configuration Objects¤

# Good: Use config dataclass
config = SmithWatermanConfig(
    temperature=1.0,
    gap_open=-10.0,
)
aligner = SmoothSmithWaterman(config, scoring_matrix=scoring)

# Avoid: Hardcoded values scattered in code

2. Preserve Input Keys¤

When implementing custom operators, preserve input data keys:

def apply(self, data, state, metadata, random_params=None, stats=None):
    result = self.process(data['input'])

    # Good: Preserve input keys
    transformed_data = {
        **data,  # Keep original keys
        'output': result,
    }

    return transformed_data, state, metadata

3. Use Appropriate Temperature¤

| Use Case | Recommended Temperature |
| --- | --- |
| Training start | 5.0 - 10.0 |
| Training end | 0.1 - 1.0 |
| Inference (soft) | 1.0 |
| Inference (hard) | 0.01 |
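
Moving from the training-start range to the training-end range is usually done with a schedule. The sketch below assumes an exponential decay and a sigmoid relaxation; both `anneal_temperature` and `soft_pass` are hypothetical helpers, and the start and end values are simply picked from the ranges in the table.

```python
import math


def anneal_temperature(step, total_steps, t_start=5.0, t_end=0.1):
    """Exponential interpolation from t_start to t_end (hypothetical helper)."""
    fraction = min(max(step / total_steps, 0.0), 1.0)
    return t_start * (t_end / t_start) ** fraction


def soft_pass(quality, threshold, temperature):
    """Numerically stable sigmoid relaxation of `quality > threshold`."""
    z = (quality - threshold) / temperature
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


# A read just above the threshold becomes a near-certain pass as training cools.
for step in (0, 500, 1000):
    t = anneal_temperature(step, total_steps=1000)
    p = soft_pass(21.0, 20.0, t)
    print(f"step={step:4d}  temperature={t:.3f}  p(pass)={p:.3f}")
```

High temperatures early in training spread gradient across many inputs, while low temperatures late in training make the soft decision approach the hard one.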

4. JIT for Performance¤

Always JIT-compile hot paths:

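from flax import nnx

@nnx.jit
def process_batch(operator, batch_data):
    # nnx.jit handles NNX modules passed as arguments.
    # The Python loop is unrolled at trace time, so for large batches
    # stack the inputs and use jax.vmap instead.
    results = []
    for data in batch_data:
        result, _, _ = operator.apply(data, {}, None)
        results.append(result)
    return results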

Next Steps¤

Data Loading¤

  • Data Sources: Load genomics data (BAM, FASTA) and molecular datasets (MolNet)
  • Dataset Splitters: Domain-aware dataset splitting (scaffold, sequence identity)

Pipelines¤

Training¤