Pipelines Overview

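DiffBio pipelines compose multiple operators into end-to-end differentiable workflows. Because every stage is differentiable, a single loss computed on the final output produces gradients for the parameters of every stage, so all pipeline components can be optimized jointly.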

Available Pipelines

| Pipeline | Description | Status |
|----------|-------------|--------|
| Variant Calling | Reads → Pileup → Variants | Implemented |
| Enhanced Variant Calling | DeepVariant-style CNN variant calling with quality recalibration | Implemented |
| Single-Cell Analysis | scVI-style VAE + Harmony batch correction + soft clustering | Implemented |
| Preprocessing | Quality filtering → Adapter removal → Error correction | Implemented |
| Differential Expression | DESeq2-style differential expression analysis | Implemented |
| Perturbation | Cell-load-style perturbation experiment loading and downstream evaluation | Implemented |

Pipeline Architecture

DiffBio pipelines follow a modular architecture:

```mermaid
graph TB
    subgraph Input
        A[Raw Data]
    end

    subgraph Pipeline
        B[Preprocessing]
        C[Core Operation]
        D[Classification]
        B --> C --> D
    end

    subgraph Output
        E[Predictions]
    end

    A --> B
    D --> E

    style A fill:#d1fae5,stroke:#059669,color:#064e3b
    style B fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style C fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style D fill:#ede9fe,stroke:#7c3aed,color:#4c1d95
    style E fill:#d1fae5,stroke:#059669,color:#064e3b
```

Pipeline Benefits

End-to-End Optimization

Traditional pipelines optimize each step independently:

```text
Input → [Optimize A] → [Optimize B] → [Optimize C] → Output
           ↑              ↑              ↑
        (separate)    (separate)     (separate)
```

DiffBio optimizes the entire pipeline jointly:

```text
Input → [A] → [B] → [C] → Output
          ↖    ↑    ↗
            (joint optimization)
```
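
The following toy sketch makes the joint-optimization idea concrete. It is written in plain JAX and is not DiffBio code. Two hypothetical stages, `stage_a` and `stage_b`, each have their own parameters and are trained by gradient descent on one loss computed at the end. Because `jax.value_and_grad` differentiates the whole composition, every update step moves the parameters of both stages at once.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01

# Hypothetical two-stage toy pipeline (not DiffBio code)
def stage_a(p, x):
    # Stage A: learnable global scale
    return p["scale"] * x

def stage_b(p, h):
    # Stage B: learnable linear readout
    return h @ p["w"]

def pipeline_loss(params, x, y):
    # One loss, computed only on the final output of the chained stages
    pred = stage_b(params["b"], stage_a(params["a"], x))
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # A single differentiation yields gradients for BOTH stages
    loss, grads = jax.value_and_grad(pipeline_loss)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

x = jnp.array([[1.0, 2.0, 3.0], [0.0, 1.0, -1.0]])
y = jnp.array([2.0, 0.5])
init_params = {"a": {"scale": jnp.array(0.5)}, "b": {"w": jnp.ones(3)}}

params = init_params
losses = []
for _ in range(200):
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))
```

In a step-by-step approach, `stage_a` would be tuned against its own objective and then frozen. Here its scale is adjusted according to how it affects the final loss.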

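Gradient Flow

Gradients flow through all pipeline stages:

```python
from flax import nnx

def pipeline_loss(pipeline, data, targets):
    # apply() runs every stage, so the loss depends on:
    #   1. quality filtering (gradients reach the threshold)
    #   2. pileup generation (gradients reach the weighting)
    #   3. classification (gradients reach the classifier weights)
    result, _, _ = pipeline.apply(data, {}, None)
    # loss_fn is a placeholder for any differentiable loss, e.g. cross-entropy
    return loss_fn(result["logits"], targets)

# A single gradient computation covers all components; nnx.grad
# differentiates with respect to the pipeline's nnx.Param variables
grads = nnx.grad(pipeline_loss)(pipeline, data, targets)
```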
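
Gradients can only reach a stage's parameters if that stage is built from differentiable operations. The toy comparison below uses plain JAX, made-up quality scores, and a hypothetical `temperature` parameter. It contrasts a hard quality cutoff, whose gradient with respect to the threshold is zero almost everywhere, with a sigmoid relaxation that gives a usable gradient. This illustrates the general idea only; DiffBio's quality filter is not necessarily implemented this way.

```python
import jax
import jax.numpy as jnp

# Toy per-read Phred quality scores and per-read signal (made-up values)
quality = jnp.array([5.0, 15.0, 25.0, 35.0])
signal = jnp.array([1.0, 2.0, 3.0, 4.0])

def hard_filter_output(threshold):
    # Step function: keep a read only if its quality exceeds the threshold
    keep = jnp.where(quality > threshold, 1.0, 0.0)
    return jnp.sum(keep * signal)

def soft_filter_output(threshold, temperature=5.0):
    # Sigmoid relaxation: the keep-weight rises smoothly with quality - threshold
    keep = jax.nn.sigmoid((quality - threshold) / temperature)
    return jnp.sum(keep * signal)

hard_grad = jax.grad(hard_filter_output)(20.0)  # zero: the step is flat
soft_grad = jax.grad(soft_filter_output)(20.0)  # negative: a higher threshold removes signal
```

Lowering `temperature` makes the soft filter approach the hard cutoff. The trade-off is that gradients become sharper near the threshold and vanish further away from it.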

Creating Pipelines

Using Factory Functions

```python
from diffbio.pipelines import create_variant_calling_pipeline

# Quick creation with sensible defaults
pipeline = create_variant_calling_pipeline(
    reference_length=1000,
    num_classes=3,
    quality_threshold=20.0,
)
```
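
A factory function of this kind typically starts from a default configuration and overrides only the fields you pass. The sketch below shows that pattern using a hypothetical frozen dataclass, `ToyPipelineConfig`. Its field names and default values are illustrative assumptions, not DiffBio's actual defaults.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ToyPipelineConfig:
    # Hypothetical fields and defaults, for illustration only
    reference_length: int = 1000
    num_classes: int = 3
    quality_threshold: float = 20.0
    pileup_window_size: int = 11
    classifier_hidden_dim: int = 64

def create_toy_pipeline_config(**overrides) -> ToyPipelineConfig:
    # Start from the defaults and replace only the fields the caller passes;
    # an unknown field name raises TypeError instead of being silently ignored
    return replace(ToyPipelineConfig(), **overrides)

cfg = create_toy_pipeline_config(reference_length=500)
```

Because the dataclass is frozen, a configuration cannot be changed after the pipeline is built from it. Any variation requires a new config.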

Manual Configuration

```python
from diffbio.pipelines import VariantCallingPipeline, VariantCallingPipelineConfig
from flax import nnx

# Full control over configuration
config = VariantCallingPipelineConfig(
    reference_length=1000,
    num_classes=3,
    quality_threshold=20.0,
    pileup_window_size=11,
    classifier_hidden_dim=64,
    use_quality_weights=True,
)

# Initialize with a seeded random number generator; nnx.Rngs(42) sets the
# default stream, which every named stream (e.g. "params") falls back to
rngs = nnx.Rngs(42)
pipeline = VariantCallingPipeline(config, rngs=rngs)
```
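
`pileup_window_size=11` presumably gives the classifier some local context around each reference position. The hypothetical helper below assumes a centered window of odd size. It turns a `(reference_length, 4)` pileup into `(reference_length, window_size, 4)` windows and zero-pads both ends of the reference. This is a plain-JAX sketch of the idea, not DiffBio's implementation.

```python
import jax
import jax.numpy as jnp

def pileup_windows(pileup: jax.Array, window_size: int) -> jax.Array:
    """Return centered windows of shape (reference_length, window_size, 4)."""
    if window_size % 2 == 0:
        raise ValueError("window_size must be odd so windows are centered")
    half = window_size // 2
    # Zero-pad both ends so edge positions still get full-size windows
    padded = jnp.pad(pileup, ((half, half), (0, 0)))
    starts = jnp.arange(pileup.shape[0])
    # The window for position i is padded[i : i + window_size],
    # i.e. original positions i - half .. i + half
    return jax.vmap(
        lambda s: jax.lax.dynamic_slice_in_dim(padded, s, window_size, axis=0)
    )(starts)

pileup = jnp.arange(20 * 4, dtype=jnp.float32).reshape(20, 4)
windows = pileup_windows(pileup, window_size=11)
```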

Using Pipelines

Single Sample Processing

```python
# Prepare input data
data = {
    "reads": reads,           # (num_reads, read_length, 4)
    "positions": positions,   # (num_reads,)
    "quality": quality,       # (num_reads, read_length)
}

# Apply pipeline (empty initial state, no metadata)
result_data, state, metadata = pipeline.apply(data, {}, None)

# Access outputs
pileup = result_data["pileup"]          # (reference_length, 4)
logits = result_data["logits"]          # (reference_length, num_classes)
probabilities = result_data["probabilities"]  # (reference_length, num_classes)
```
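
The `pileup` output summarizes how strongly each base is supported by the reads at every reference position. Below is a minimal plain-JAX sketch of a quality-weighted pileup under simplifying assumptions:

- Reads are one-hot encoded.
- Reads are aligned without insertions or deletions, starting at `positions`.
- Each base is weighted by its Phred-derived probability of being correct.

The function name `toy_pileup` is hypothetical, and this is not DiffBio's `DifferentiablePileup`.

```python
import jax
import jax.numpy as jnp

def toy_pileup(reads, positions, quality, reference_length):
    """Quality-weighted base frequencies per reference position (sketch)."""
    read_length = reads.shape[1]
    # Reference coordinate covered by each base of each read (no indels assumed)
    coords = positions[:, None] + jnp.arange(read_length)[None, :]
    # Probability that a base call is correct, from Phred: 1 - 10^(-Q/10)
    weights = 1.0 - 10.0 ** (-quality / 10.0)
    contrib = reads * weights[..., None]
    # Scatter-add every base into its reference position; coordinates past
    # the end of the reference are dropped
    counts = jnp.zeros((reference_length, 4)).at[coords.reshape(-1)].add(
        contrib.reshape(-1, 4), mode="drop"
    )
    # Normalize each position to a distribution over A/C/G/T (eps avoids 0/0)
    return counts / (counts.sum(axis=-1, keepdims=True) + 1e-8)

# Two reads of length 3: "ACG" starting at 0 and "ACT" starting at 1
reads = jax.nn.one_hot(jnp.array([[0, 1, 2], [0, 1, 3]]), 4)
positions = jnp.array([0, 1])
quality = jnp.full((2, 3), 30.0)
pileup = toy_pileup(reads, positions, quality, reference_length=5)
```

The sketch uses only elementwise arithmetic and a scatter-add, so the result is differentiable with respect to both `reads` and `quality`.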

Batch Processing

```python
from datarax.typing import Batch, Element

# Create batch
elements = [
    Element(data=sample_data, state={}, metadata={})
    for sample_data in samples
]
batch = Batch.from_elements(elements)

# Process batch
result_batch = pipeline.apply_batch(batch)
```
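
When all samples share the same array shapes, JAX can also batch a single-sample function directly with `jax.vmap`. The sketch below uses a stand-in function, `process_sample`, in place of a real pipeline. It illustrates the vectorization pattern only and is not a replacement for `apply_batch`.

```python
import jax
import jax.numpy as jnp

def process_sample(reads, weights):
    # Stand-in for a single-sample pipeline: weighted per-position base sums
    # reads: (num_reads, read_length, 4), weights: (num_reads,)
    return jnp.einsum("rlb,r->lb", reads, weights)

# vmap adds a leading batch axis to every input and output
batched = jax.vmap(process_sample)

reads = jnp.ones((8, 5, 10, 4))   # (batch, num_reads, read_length, 4)
weights = jnp.ones((8, 5))        # (batch, num_reads)
out = batched(reads, weights)     # (batch, read_length, 4)
```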

Training Mode

```python
# Enable dropout and other training-specific behavior
pipeline.train_mode()

# ... training loop ...

# Disable for inference
pipeline.eval_mode()
```
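
Dropout is the typical example of behavior that differs between the two modes. The sketch below implements standard inverted dropout in plain JAX with an explicit `training` flag. It assumes that `train_mode()` and `eval_mode()` toggle an equivalent flag inside the pipeline's layers.

```python
import jax
import jax.numpy as jnp

def dropout(x, rate, key, training):
    if not training:
        # Eval mode: identity, so predictions are deterministic
        return x
    # Train mode: zero each unit with probability `rate` and rescale the
    # survivors by 1 / (1 - rate) so the expected activation is unchanged
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

x = jnp.ones((1000,))
key = jax.random.PRNGKey(0)
train_out = dropout(x, 0.5, key, training=True)
eval_out = dropout(x, 0.5, key, training=False)
```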

Pipeline Components

Each pipeline composes multiple operators:

Variant Calling Pipeline

```python
class VariantCallingPipeline:
    def __init__(self, config, rngs):
        # 1. Quality filter
        self.quality_filter = DifferentiableQualityFilter(...)

        # 2. Pileup generator
        self.pileup = DifferentiablePileup(...)

        # 3. Variant classifier
        self.classifier = VariantClassifier(...)
```

Accessing Sub-Components

```python
pipeline = create_variant_calling_pipeline(reference_length=100)

# Access individual operators
print(pipeline.quality_filter.threshold)  # Quality threshold
print(pipeline.pileup.temperature)        # Pileup temperature
print(pipeline.classifier)                # Neural network classifier
```
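
Because the pipeline is built from named sub-components, its parameters form a nested tree with one branch per stage. The hypothetical dictionary below mirrors that layout in plain JAX; the classifier shapes are made up. The sketch counts parameters per component with `jax.tree_util.tree_leaves`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree: one branch per pipeline stage
params = {
    "quality_filter": {"threshold": jnp.array(20.0)},
    "pileup": {"temperature": jnp.array(1.0)},
    "classifier": {
        "w1": jnp.zeros((44, 64)), "b1": jnp.zeros((64,)),
        "w2": jnp.zeros((64, 3)), "b2": jnp.zeros((3,)),
    },
}

def count_params(tree):
    # Sum the sizes of all array leaves in a (sub)tree
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))

counts = {name: count_params(subtree) for name, subtree in params.items()}
```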

Custom Pipelines

Create custom pipelines by composing operators:

```python
from datarax.core.operator import OperatorModule
from flax import nnx

class CustomPipeline(OperatorModule):
    def __init__(self, config, rngs):
        super().__init__(config, rngs=rngs)

        # Initialize your operators (Operator1-3 and the op*_config fields
        # are placeholders for your own operator classes and configs)
        self.op1 = Operator1(config.op1_config, rngs=rngs)
        self.op2 = Operator2(config.op2_config, rngs=rngs)
        self.op3 = Operator3(config.op3_config, rngs=rngs)

    def apply(self, data, state, metadata, random_params=None, stats=None):
        # Chain operators: each consumes and returns (data, state, metadata)
        data, state, metadata = self.op1.apply(data, state, metadata)
        data, state, metadata = self.op2.apply(data, state, metadata)
        data, state, metadata = self.op3.apply(data, state, metadata)

        return data, state, metadata
```
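
The chaining pattern in `apply` does not depend on what the operators do. It only requires that each one accepts and returns a `(data, state, metadata)` triple. The framework-free sketch below uses plain functions as hypothetical operators to show the data flow. `chain`, `scale_op`, `shift_op`, and `count_op` are illustrative names, not datarax APIs.

```python
from typing import Any, Callable, Dict, Tuple

Triple = Tuple[Dict[str, Any], Dict[str, Any], Any]

def scale_op(data, state, metadata):
    # Hypothetical operator: rescale a signal
    return {**data, "x": data["x"] * 2.0}, state, metadata

def shift_op(data, state, metadata):
    # Hypothetical operator: shift a signal
    return {**data, "x": data["x"] + 1.0}, state, metadata

def count_op(data, state, metadata):
    # Operators can also update state, e.g. count processed samples
    return data, {**state, "seen": state.get("seen", 0) + 1}, metadata

def chain(*ops: Callable[..., Triple]) -> Callable[..., Triple]:
    def apply(data, state, metadata):
        # Thread the triple through each operator in order
        for op in ops:
            data, state, metadata = op(data, state, metadata)
        return data, state, metadata
    return apply

pipeline = chain(scale_op, shift_op, count_op)
data, state, metadata = pipeline({"x": 3.0}, {}, None)
```

Each operator returns new dictionaries instead of mutating its inputs. This keeps every operator a pure function, which is what JAX transformations such as `jax.jit` and `jax.grad` expect.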

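Best Practices

1. Match Reference Lengths

All operators in a pipeline must use the same reference length. Otherwise, the arrays passed from one stage to the next will not have matching shapes.

```python
# Define the reference length once and reuse it for every config
ref_len = 1000
pileup_config = PileupConfig(reference_length=ref_len)
pipeline_config = VariantCallingPipelineConfig(reference_length=ref_len)
```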
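
One way to catch a mismatch early is to validate the configs before building the pipeline. The helper `check_reference_lengths` and the two stand-in config classes below are hypothetical. The sketch assumes only that each config exposes a `reference_length` attribute.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyPileupConfig:        # hypothetical stand-in for a real config
    reference_length: int

@dataclass(frozen=True)
class ToyCallerConfig:        # hypothetical stand-in for a real config
    reference_length: int

def check_reference_lengths(*configs) -> int:
    # Collect the distinct reference lengths; exactly one is allowed
    lengths = {c.reference_length for c in configs}
    if len(lengths) != 1:
        raise ValueError(f"reference_length mismatch: {sorted(lengths)}")
    return lengths.pop()

ref_len = check_reference_lengths(ToyPileupConfig(1000), ToyCallerConfig(1000))
```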

2. Set Mode Appropriately

```python
# Training
pipeline.train_mode()
for batch in train_data:
    loss = train_step(pipeline, batch)

# Evaluation
pipeline.eval_mode()
for batch in test_data:
    metrics = evaluate(pipeline, batch)
```

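3. JIT Compile for Performance

```python
from flax import nnx

# nnx.jit is the NNX-aware version of jax.jit and accepts modules as arguments
@nnx.jit
def predict(pipeline, data):
    result, _, _ = pipeline.apply(data, {}, None)
    return result["probabilities"]

# The first call traces and compiles; later calls with the same input
# shapes reuse the compiled function
preds = predict(pipeline, sample_data)
```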
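
`jax.jit` traces a function once per distinct combination of input shapes and dtypes, and caches the compiled result. The plain-JAX sketch below makes this visible with a Python side effect, which runs only during tracing. The function is a stand-in, not the pipeline's `apply`.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def predict(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return jax.nn.softmax(x, axis=-1)

predict(jnp.zeros((100, 3)))  # first call: trace and compile
predict(jnp.ones((100, 3)))   # same shape and dtype: reuses compiled code
predict(jnp.zeros((200, 3)))  # new shape: traces and compiles again
```

Keeping `reference_length` and other input shapes fixed across calls therefore avoids repeated compilation.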

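4. Save and Load Checkpoints

```python
import pickle
from flax import nnx

# Save only the trainable parameters (nnx.Param variables).
# For larger models, Orbax (the checkpointing library used with Flax)
# is a more robust alternative to pickle.
state = nnx.state(pipeline, nnx.Param)
with open("checkpoint.pkl", "wb") as f:
    pickle.dump(state, f)

# Load the parameters back into an already-constructed pipeline
with open("checkpoint.pkl", "rb") as f:
    state = pickle.load(f)
nnx.update(pipeline, state)
```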
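
The same save/load pattern works for any JAX pytree. The sketch below pickles a plain dictionary of arrays as a stand-in for the pipeline's parameters. It calls `jax.device_get` first so that host NumPy arrays, rather than device arrays, are written to disk, and then checks that the round trip preserves every value.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

# Stand-in parameter tree (hypothetical names and values)
params = {
    "classifier": {"w": jnp.arange(6.0).reshape(2, 3)},
    "threshold": jnp.array(20.0),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pkl")
    with open(path, "wb") as f:
        # device_get copies every leaf to a host NumPy array before pickling
        pickle.dump(jax.device_get(params), f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

# Leaves come back as NumPy arrays; the tree structure is preserved
same = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda a, b: bool(np.array_equal(a, b)), params, restored)
)
```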

Next Steps