Operators Overview
DiffBio provides a collection of differentiable operators for bioinformatics analysis. Each operator inherits from Datarax's OperatorModule for consistent interfaces and composability.
Foundation: Soft Operations
All DiffBio operators are built on a shared layer of differentiable primitives in diffbio.core.soft_ops. This module provides 79 smooth relaxations of discrete operations (sorting, comparisons, logical gates, selection) with 5 smoothness modes and multiple algorithmic backends.
Operators do not call JAX's hard jnp.max, jnp.argmax, or jnp.sort directly. Instead, they use soft_ops equivalents that produce well-defined gradients:
from diffbio.core import soft_ops
# Soft quality threshold (returns probability, not boolean)
is_good = soft_ops.greater(quality, 20.0, softness=1.0)
# Soft sorting for gene ranking
ranked = soft_ops.sort(expression, axis=0, softness=0.1)
# Soft top-k variant selection
values, indices = soft_ops.top_k(scores, k=10, softness=0.1)
See the Soft Operations concept guide for a detailed walkthrough, or the API reference for the full function list.
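To make the idea of a "soft" comparison concrete, the sketch below contrasts a hard threshold with a sigmoid relaxation in plain JAX. It is not DiffBio's implementation: `soft_greater` is a hypothetical helper, and the sigmoid form is only one common way to smooth a comparison, assumed here for illustration.

```python
import jax
import jax.numpy as jnp


def soft_greater(x, threshold, softness=1.0):
    """Sigmoid relaxation of `x > threshold`; returns values in (0, 1).

    Hypothetical helper for illustration, not the soft_ops implementation.
    """
    return jax.nn.sigmoid((x - threshold) / softness)


quality = jnp.array([10.0, 20.0, 30.0])

# Hard comparison: boolean output, so no useful gradient flows through it.
hard = quality > 20.0

# Soft comparison: 0.5 exactly at the threshold, approaching 0 or 1 away from it.
soft = soft_greater(quality, 20.0, softness=1.0)

# The relaxed version has a nonzero gradient, largest near the threshold.
grad = jax.grad(lambda q: soft_greater(q, 20.0).sum())(quality)
```

Smaller `softness` values make the sigmoid steeper, so the output moves closer to the hard comparison while the gradient concentrates around the threshold.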
Available Operators
Core Operators
| Operator | Description | Status |
|---|---|---|
| SoftProgressiveMSA | Differentiable multiple sequence alignment with guide tree | Implemented |
| ProfileHMMSearch | Profile Hidden Markov Model for sequence homology | Implemented |

| Operator | Description | Status |
|---|---|---|
| DifferentiablePeakCaller | CNN-based peak calling for ChIP-seq/ATAC-seq | Implemented |
| ChromatinStateAnnotator | HMM-based chromatin state classification | Implemented |

| Operator | Description | Status |
|---|---|---|
| SplicingPSI | Differentiable PSI calculation for alternative splicing | Implemented |
| DifferentiableMotifDiscovery | Learnable PWM-based motif discovery | Implemented |

| Operator | Description | Status |
|---|---|---|
| SoftKMeansClustering | Differentiable soft k-means with learnable centroids | Implemented |
| DifferentiableHarmony | Harmony-style batch correction | Implemented |
| DifferentiableVelocity | RNA velocity via neural ODEs | Implemented |
| DifferentiableAmbientRemoval | VAE-based ambient RNA decontamination | Implemented |

| Operator | Description | Status |
|---|---|---|
| SoftAdapterRemoval | Differentiable adapter trimming with soft alignment | Implemented |
| DifferentiableDuplicateWeighting | Probabilistic duplicate weighting | Implemented |
| SoftErrorCorrection | Neural network-based error correction | Implemented |

| Operator | Description | Status |
|---|---|---|
| VAENormalizer | scVI-style VAE for count normalization | Implemented |
| DifferentiableUMAP | Differentiable UMAP dimensionality reduction | Implemented |
| SequenceEmbedding | Learned sequence embeddings | Implemented |

| Operator | Description | Status |
|---|---|---|
| DifferentiableHMM | Forward algorithm with logsumexp stability | Implemented |
| DifferentiableNBGLM | Negative binomial GLM for differential expression | Implemented |
| DifferentiableEMQuantifier | Unrolled EM for transcript quantification | Implemented |

| Operator | Description | Status |
|---|---|---|
| GNNAssemblyNavigator | GNN for assembly graph traversal | Implemented |
| NeuralReadMapper | Cross-attention based read mapping | Implemented |
| DifferentiableMetagenomicBinner | VAMB-style VAE for metagenomic binning | Implemented |

| Operator | Description | Status |
|---|---|---|
| SpatialDeconvolution | Cell type deconvolution for spatial transcriptomics | Implemented |
| HiCContactAnalysis | Chromatin contact analysis for Hi-C data | Implemented |
| DifferentiableSpatialGeneDetector | SpatialDE-style spatial gene detection | Implemented |

| Operator | Description | Status |
|---|---|---|
| CNNVariantClassifier | CNN-based variant classification | Implemented |
| DifferentiableCNVSegmentation | Copy number variation segmentation | Implemented |
| SoftVariantQualityFilter | Base quality score recalibration | Implemented |
| DeepVariantStylePileup | Multi-channel pileup image generation for DeepVariant-style CNNs | Implemented |

| Operator | Description | Status |
|---|---|---|
| DifferentiableAncestryEstimator | Neural ADMIXTURE-style ancestry estimation | Implemented |

| Operator | Description | Status |
|---|---|---|
| DifferentiableCRISPRScorer | DeepCRISPR-style guide RNA efficiency prediction | Implemented |

| Operator | Description | Status |
|---|---|---|
| DifferentiableSpectralSimilarity | MS2DeepScore-style Siamese network for MS/MS similarity | Implemented |

| Operator | Description | Status |
|---|---|---|
| DifferentiableSecondaryStructure | PyDSSP-style DSSP with continuous H-bond matrix | Implemented |

| Operator | Description | Status |
|---|---|---|
| TransformerSequenceEncoder | DNABERT/RNA-FM-style transformer for sequence embedding | Implemented |

| Operator | Description | Status |
|---|---|---|
| DifferentiableRNAFold | McCaskill-style partition function for base pair probabilities | Implemented |

| Operator | Description | Status |
|---|---|---|
| ForceFieldOperator | Differentiable force field (LJ, Morse, Soft Sphere) using JAX-MD | Implemented |
| MDIntegratorOperator | Time integration for MD (velocity Verlet, Langevin) using JAX-MD | Implemented |

| Operator | Description | Status |
|---|---|---|
| MolecularPropertyPredictor | ChemProp-style D-MPNN for molecular property prediction | Implemented |
| ADMETPredictor | Multi-task ADMET prediction (22 TDC endpoints) | Implemented |
| DifferentiableMolecularFingerprint | Neural graph fingerprints as alternative to ECFP/Morgan | Implemented |
| CircularFingerprintOperator | Differentiable ECFP/Morgan circular fingerprints | Implemented |
| MACCSKeysOperator | Differentiable MACCS 166 structural keys fingerprint | Implemented |
| AttentiveFP | Attention-based graph fingerprint with GRU (Xiong et al. 2019) | Implemented |
| MolecularSimilarityOperator | Differentiable Tanimoto/cosine/Dice similarity | Implemented |
Operator Interface
All DiffBio operators implement the Datarax OperatorModule interface:
class OperatorModule:
    def apply(
        self,
        data: PyTree,
        state: PyTree,
        metadata: dict | None,
        random_params: Any = None,
        stats: dict | None = None,
    ) -> tuple[PyTree, PyTree, dict | None]:
        """Transform data through the operator.

        Args:
            data: Input data as a PyTree (typically dict)
            state: Per-element state (passed through or modified)
            metadata: Optional metadata
            random_params: Random parameters for stochastic operators
            stats: Optional statistics dictionary

        Returns:
            Tuple of (transformed_data, updated_state, updated_metadata)
        """
Configuration Pattern
Each operator has a corresponding configuration dataclass:
from dataclasses import dataclass

from datarax.core.config import OperatorConfig


@dataclass(frozen=True)
class MyOperatorConfig(OperatorConfig):
    """Configuration for MyOperator."""

    temperature: float = 1.0
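Because the configuration is a frozen dataclass, it cannot be mutated after construction; a changed setting means a new config object. The sketch below shows that behavior using only the standard library. `BaseConfig` is a hypothetical stand-in for `OperatorConfig` so the example runs without Datarax installed; the real base class may add validation or fields not shown here.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class BaseConfig:
    """Hypothetical stand-in for datarax's OperatorConfig."""


@dataclass(frozen=True)
class MyOperatorConfig(BaseConfig):
    """Configuration for a hypothetical operator."""

    temperature: float = 1.0


config = MyOperatorConfig(temperature=5.0)

# Frozen dataclasses reject in-place mutation.
try:
    config.temperature = 0.1
    mutated = True
except FrozenInstanceError:
    mutated = False

# Derive a modified copy instead; the original config is left untouched.
sharper = replace(config, temperature=0.1)
```

Immutable configs are hashable and safe to share between operators, which is useful when the same settings feed several pipeline stages.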
Example Usage
from diffbio.operators import DifferentiableQualityFilter, QualityFilterConfig
# 1. Create configuration
config = QualityFilterConfig(initial_threshold=20.0)
# 2. Instantiate operator
operator = DifferentiableQualityFilter(config)
# 3. Prepare data
data = {
    "sequence": sequence_tensor,
    "quality_scores": quality_tensor,
}
# 4. Apply operator
result_data, state, metadata = operator.apply(data, {}, None)
Composing Operators
Operators can be composed into pipelines using Datarax's composition utilities:
Sequential Composition
from datarax.operators import CompositeOperatorModule
# Chain operators sequentially
pipeline = CompositeOperatorModule([
    quality_filter,
    aligner,
    pileup_generator,
])
# Apply entire pipeline
result, state, meta = pipeline.apply(data, {}, None)
Manual Composition
def my_pipeline(data):
    # Step 1: Quality filtering
    data, state, meta = quality_filter.apply(data, {}, None)
    # Step 2: Alignment
    data, state, meta = aligner.apply(data, state, meta)
    # Step 3: Pileup
    data, state, meta = pileup_op.apply(data, state, meta)
    return data
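Both composition styles thread the same `(data, state, metadata)` triple from one operator to the next. The sketch below illustrates that threading with plain functions in place of operators. `chain`, `scale_op`, and `sum_op` are hypothetical names, and this is not how `CompositeOperatorModule` is implemented; it only shows the data flow.

```python
from functools import reduce


def scale_op(data, state, metadata):
    """Toy operator: adds a scaled copy of 'x' and keeps existing keys."""
    return {**data, "scaled": [2 * v for v in data["x"]]}, state, metadata


def sum_op(data, state, metadata):
    """Toy operator: reduces 'scaled' to a scalar and counts calls in state."""
    new_state = {**state, "calls": state.get("calls", 0) + 1}
    return {**data, "total": sum(data["scaled"])}, new_state, metadata


def chain(*ops):
    """Compose operators left to right, threading (data, state, metadata)."""
    def pipeline(data, state, metadata):
        return reduce(lambda triple, op: op(*triple), ops, (data, state, metadata))
    return pipeline


pipeline = chain(scale_op, sum_op)
result, state, meta = pipeline({"x": [1, 2, 3]}, {}, None)
```

Each stage sees the keys written by earlier stages, which is why operators that preserve their input keys compose cleanly.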
Learnable Parameters
DiffBio operators use Flax NNX for parameter management:
Accessing Parameters
from flax import nnx
# Get all parameters
params = nnx.state(operator, nnx.Param)
print(params)
# Access specific parameter
print(operator.threshold[...]) # Array value
Updating Parameters
# Manual update
operator.threshold[...] = new_value
# Gradient-based update
import optax
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)
def update_step(params, grads, opt_state):
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state
All operators are compatible with JAX transformations:
JIT Compilation
import jax
@jax.jit
def apply_operator(data):
    result, _, _ = operator.apply(data, {}, None)
    return result
# Fast execution after first compile
result = apply_operator(data)
Vectorization
# Process batch of inputs
def single_apply(single_data):
    result, _, _ = operator.apply(single_data, {}, None)
    return result
batch_apply = jax.vmap(single_apply)
batch_results = batch_apply(batch_data)
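When the input is a dict, `jax.vmap` maps over the leading axis of every array in it and stacks the per-example outputs. The sketch below shows this with a toy scoring function standing in for an operator; the key name `quality_scores`, the threshold of 20, and the scoring rule are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def single_apply(single_data):
    """Toy per-read score: mean soft pass-probability over positions."""
    keep = jax.nn.sigmoid(single_data["quality_scores"] - 20.0)
    return {"score": keep.mean()}


# A batch is a dict whose leaves carry a leading batch axis: 4 reads x 5 positions.
batch_data = {"quality_scores": jnp.full((4, 5), 20.0)}

# vmap maps over axis 0 of every leaf and stacks the per-example outputs.
batch_results = jax.vmap(single_apply)(batch_data)
```

The function body is written for a single example; `vmap` supplies the batch dimension, so `batch_results["score"]` has one entry per read.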
Gradient Computation
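from flax import nnx

def loss_fn(operator, data):
    result, _, _ = operator.apply(data, {}, None)
    return result['score'].mean()

# Compute gradients w.r.t. operator parameters.
# nnx.grad differentiates with respect to the nnx.Param variables of the
# first argument; jax.grad on a function of data alone would instead
# return gradients w.r.t. the input data.
grad_fn = nnx.grad(loss_fn)
grads = grad_fn(operator, data)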
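Which argument is differentiated matters: `jax.grad` differentiates with respect to the first positional argument unless `argnums` says otherwise. The sketch below makes the distinction explicit with a toy soft quality filter in plain JAX. The parameter dict, the key names, and the sigmoid form are assumptions for illustration, not DiffBio's API.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, data):
    """Mean soft pass-probability of a toy quality filter."""
    # Sigmoid relaxation of `quality > threshold` (an assumed form).
    keep = jax.nn.sigmoid(data["quality_scores"] - params["threshold"])
    return keep.mean()


params = {"threshold": jnp.array(20.0)}
data = {"quality_scores": jnp.array([18.0, 20.0, 25.0])}

# argnums=0: gradient with respect to the parameter pytree only.
param_grads = jax.grad(loss_fn, argnums=0)(params, data)

# argnums=1: gradient with respect to the input data instead.
data_grads = jax.grad(loss_fn, argnums=1)(params, data)
```

Raising the threshold lowers the pass probability, so the parameter gradient is negative, while the gradient with respect to each quality score is positive.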
Best Practices
1. Use Configuration Objects
# Good: Use config dataclass
config = SmithWatermanConfig(
    temperature=1.0,
    gap_open=-10.0,
)
aligner = SmoothSmithWaterman(config, scoring_matrix=scoring)
# Avoid: Hardcoded values scattered in code
2. Preserve Input Data Keys
When implementing custom operators, preserve input data keys:
def apply(self, data, state, metadata, random_params=None, stats=None):
    result = self.process(data['input'])
    # Good: Preserve input keys
    transformed_data = {
        **data,  # Keep original keys
        'output': result,
    }
    return transformed_data, state, metadata
3. Use Appropriate Temperature
| Use Case | Recommended Temperature |
|---|---|
| Training start | 5.0 - 10.0 |
| Training end | 0.1 - 1.0 |
| Inference (soft) | 1.0 |
| Inference (hard) | 0.01 |
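Moving from the training-start range to the training-end range is usually done with an annealing schedule. The sketch below uses exponential interpolation, which is an assumed choice; the table only gives the ranges and does not prescribe a schedule. `annealed_temperature` is a hypothetical helper.

```python
def annealed_temperature(step, total_steps, start=5.0, end=0.1):
    """Exponentially interpolate from `start` to `end` over `total_steps`."""
    # Clamp progress to [0, 1] so steps past the end keep the final value.
    frac = min(max(step / total_steps, 0.0), 1.0)
    return start * (end / start) ** frac


# Temperatures at five evenly spaced points in a 1000-step run.
schedule = [annealed_temperature(s, 1000) for s in (0, 250, 500, 750, 1000)]
```

Starting warm keeps gradients informative while parameters are far from their targets; cooling toward the end brings the soft outputs closer to the discrete behavior used at inference.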
4. JIT-Compile Hot Paths
Always JIT-compile hot paths:
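from flax import nnx

# nnx.jit handles NNX modules passed as arguments.
# The Python loop is unrolled at trace time, so batch_data should be a
# fixed-length sequence.
@nnx.jit
def process_batch(operator, batch_data):
    results = []
    for data in batch_data:
        result, _, _ = operator.apply(data, {}, None)
        results.append(result)
    return results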
Next Steps
Data Loading
- Data Sources: Load genomics data (BAM, FASTA) and molecular datasets (MolNet)
- Dataset Splitters: Domain-aware dataset splitting (scaffold, sequence identity)
Pipelines
Training