Enhanced Variant Calling Pipeline¤
The EnhancedVariantCallingPipeline is an end-to-end differentiable pipeline for advanced variant calling, featuring DeepVariant-style CNN classification and VQSR-style quality recalibration.
Overview¤
The pipeline processes sequencing reads through four stages:
graph LR
A[Reads] --> B[Quality Filter]
B --> C[Pileup Generation]
C --> D[CNN Classifier]
D --> E[Quality Recalibration]
E --> F[Predictions]
style A fill:#d1fae5,stroke:#059669,color:#064e3b
style B fill:#e0e7ff,stroke:#4338ca,color:#312e81
style C fill:#dbeafe,stroke:#2563eb,color:#1e3a5f
style D fill:#ede9fe,stroke:#7c3aed,color:#4c1d95
style E fill:#fef3c7,stroke:#d97706,color:#78350f
style F fill:#d1fae5,stroke:#059669,color:#064e3b
- Quality Filter (optional): Soft-masks low-quality bases using a learnable threshold
- Pileup Generation: Aggregates filtered reads into position-wise distributions
- CNN Classifier: DeepVariant-style convolutional neural network for variant classification
- Quality Recalibration (optional): VQSR-style GMM-based quality score recalibration
Quick Start¤
from diffbio.pipelines import create_enhanced_variant_calling_pipeline
import jax
import jax.numpy as jnp
# Create pipeline
pipeline = create_enhanced_variant_calling_pipeline(
    reference_length=100,
    num_classes=3,  # ref/snp/indel
    pileup_window_size=11,
)
# Prepare data
num_reads = 20
read_length = 30
data = {
    "reads": jax.random.uniform(jax.random.PRNGKey(0), (num_reads, read_length, 4)),
    "positions": jax.random.randint(jax.random.PRNGKey(1), (num_reads,), 0, 70),
    # minval/maxval must be keywords: the third positional argument of uniform is dtype
    "quality": jax.random.uniform(
        jax.random.PRNGKey(2), (num_reads, read_length), minval=10.0, maxval=40.0
    ),
}
data["reads"] = jax.nn.softmax(data["reads"], axis=-1) # Normalize
# Run pipeline
result, _, _ = pipeline.apply(data, {}, None)
print(f"Pileup shape: {result['pileup'].shape}") # (100, 4)
print(f"Predictions shape: {result['probabilities'].shape}") # (100, 3)
print(f"Quality scores shape: {result['quality_scores'].shape}") # (100,)
Configuration¤
EnhancedVariantCallingPipelineConfig¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| reference_length | int | 1000 | Length of reference sequence |
| num_classes | int | 3 | Number of variant classes (ref/SNP/indel) |
| quality_threshold | float | 20.0 | Initial quality threshold for filtering |
| pileup_window_size | int | 11 | Context window for pileup and CNN |
| cnn_input_height | int | 100 | Height of pileup image for CNN |
| cnn_hidden_channels | list[int] | [64, 128, 256] | CNN hidden channels |
| cnn_fc_dims | list[int] | [256, 128] | Fully connected layer dimensions |
| cnn_dropout_rate | float | 0.1 | Dropout rate for CNN classifier |
| quality_recal_n_components | int | 3 | GMM components for quality recalibration |
| quality_recal_n_features | int | 4 | Features for quality recalibration |
| quality_recal_threshold | float | 0.5 | Threshold for quality filtering |
| enable_preprocessing | bool | True | Enable quality filtering preprocessing |
| enable_quality_recalibration | bool | True | Enable quality recalibration |
from diffbio.pipelines import EnhancedVariantCallingPipeline, EnhancedVariantCallingPipelineConfig
from flax import nnx
config = EnhancedVariantCallingPipelineConfig(
    reference_length=10000,
    num_classes=3,
    quality_threshold=20.0,
    pileup_window_size=21,
    cnn_hidden_channels=(64, 128, 256),
    cnn_fc_dims=(256, 128),
    cnn_dropout_rate=0.2,
    enable_preprocessing=True,
    enable_quality_recalibration=True,
)
pipeline = EnhancedVariantCallingPipeline(config, rngs=nnx.Rngs(42))
Input Format¤
The pipeline expects a dictionary with three keys:
reads¤
One-hot encoded reads with shape (num_reads, read_length, 4):
# Hard one-hot encoding
read_indices = jnp.array([0, 1, 2, 3, ...]) # A=0, C=1, G=2, T=3
reads = jnp.eye(4)[read_indices]
# Or soft probabilities from base calling
reads = base_caller_output # Already (num_reads, read_length, 4)
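When reads start out as strings, the hard one-hot encoding can be produced with a small helper. The sketch below is illustrative only: `one_hot_encode` and `BASE_TO_INDEX` are hypothetical names, not part of diffbio, and the helper assumes all reads have the same length and contain only A, C, G, and T.

```python
import jax.numpy as jnp

# Assumed mapping, matching the A=0, C=1, G=2, T=3 convention above.
BASE_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}


def one_hot_encode(sequences):
    """Encode equal-length DNA strings as a (num_reads, read_length, 4) array."""
    indices = jnp.array([[BASE_TO_INDEX[base] for base in seq] for seq in sequences])
    # Indexing the identity matrix selects one one-hot row per base.
    return jnp.eye(4)[indices]


reads = one_hot_encode(["ACGT", "TTGA"])
print(reads.shape)
```

Ambiguous bases such as N are not handled here; one option would be to map them to a uniform soft distribution over the four bases.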
positions¤
Starting positions for each read:
positions = jnp.array([100, 250, 430, ...]) # (num_reads,)
# Read i covers positions[i] to positions[i] + read_length - 1
quality¤
Phred quality scores for each base:
quality = jnp.array([
    [30, 35, 28, 40, ...],  # Read 1 qualities
    [25, 30, 32, 35, ...],  # Read 2 qualities
])  # (num_reads, read_length)
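For orientation, a Phred score Q corresponds to a base-call error probability of 10^(-Q/10), so Q=20 means a 1% chance that the base is wrong. The conversion is sketched below; `phred_to_error_prob` is a hypothetical helper name, not a diffbio function.

```python
import jax.numpy as jnp


def phred_to_error_prob(quality):
    """Convert Phred scores to error probabilities: p = 10 ** (-Q / 10)."""
    return jnp.power(10.0, -quality / 10.0)


quality = jnp.array([[10.0, 20.0, 30.0]])
error_prob = phred_to_error_prob(quality)
print(error_prob)
```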
Output Format¤
The pipeline returns a dictionary with:
| Key | Shape | Description |
|---|---|---|
| reads | (num_reads, read_length, 4) | Original reads |
| positions | (num_reads,) | Original positions |
| quality | (num_reads, read_length) | Original quality |
| pileup | (reference_length, 4) | Aggregated nucleotide distribution |
| logits | (reference_length, num_classes) | Raw classifier output |
| probabilities | (reference_length, num_classes) | Softmax probabilities |
| quality_scores | (reference_length,) | Recalibrated quality scores* |
| filter_weights | (reference_length,) | Soft filter weights* |
*Only present when enable_quality_recalibration=True
Pipeline Stages¤
Stage 1: Quality Filtering (Optional)¤
The quality filter applies soft masking based on quality scores:
# Internal operation
retention_weight = sigmoid(quality - threshold)
filtered_reads = reads * retention_weight
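The soft mask can be sketched in plain JAX to show why the threshold is learnable. This is a minimal illustration under stated assumptions, not the library's implementation: `soft_quality_filter` is a hypothetical name, and the sigmoid is applied without any temperature or scaling factor.

```python
import jax
import jax.numpy as jnp


def soft_quality_filter(reads, quality, threshold):
    """Scale each base by sigmoid(quality - threshold)."""
    retention_weight = jax.nn.sigmoid(quality - threshold)  # (num_reads, read_length)
    # Add a trailing axis so the weight broadcasts over the 4 base channels.
    return reads * retention_weight[..., None]


reads = jnp.eye(4)[jnp.array([[0, 1, 2, 3]])]  # one read: A, C, G, T
quality = jnp.array([[40.0, 20.0, 5.0, 30.0]])
filtered = soft_quality_filter(reads, quality, threshold=20.0)

# The mask is smooth, so the threshold itself receives a gradient.
grad_threshold = jax.grad(
    lambda t: soft_quality_filter(reads, quality, t).sum()
)(20.0)
```

A base whose quality equals the threshold keeps half of its weight, and raising the threshold lowers the total retained mass, which is why the gradient is negative here.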
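When the filter is enabled, the current threshold is exposed as pipeline.quality_filter.threshold[...] (see Accessing Components).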
Stage 2: Pileup Generation¤
Aggregates reads into position-wise nucleotide distributions:
# At each reference position, compute weighted sum of read bases
# weighted by quality and position overlap
pileup[pos] = softmax(weighted_base_counts[pos])
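One way to realize this aggregation is a scatter-add of quality-weighted bases followed by a softmax. The sketch below is an assumption-laden illustration, not the library's code: `soft_pileup` is a hypothetical name, the temperature is assumed to divide the counts before the softmax, and all reads are assumed to lie fully inside the reference.

```python
import jax
import jax.numpy as jnp


def soft_pileup(reads, positions, quality, reference_length, temperature=1.0):
    """Quality-weighted pileup over reference positions."""
    num_reads, read_length, _ = reads.shape
    # Reference coordinate of every base: (num_reads, read_length)
    coords = positions[:, None] + jnp.arange(read_length)[None, :]
    # Weight each base's (soft) one-hot vector by its quality score.
    weighted = reads * quality[..., None]
    # Scatter-add all bases into their reference positions.
    counts = jnp.zeros((reference_length, 4)).at[coords.reshape(-1)].add(
        weighted.reshape(-1, 4)
    )
    # Softmax turns weighted counts into a per-position distribution.
    return jax.nn.softmax(counts / temperature, axis=-1)


reads = jnp.eye(4)[jnp.array([[0, 0, 1], [1, 2, 3]])]  # "AAC" and "CGT"
positions = jnp.array([0, 2])
quality = jnp.full((2, 3), 30.0)
pileup = soft_pileup(reads, positions, quality, reference_length=6)
```

Positions with no coverage end up with a uniform distribution over the four bases, since the softmax of all-zero counts is uniform.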
Stage 3: CNN Classification¤
DeepVariant-style convolutional neural network classifier:
# For each position, create a pileup image
# Extract context window around position
# Apply CNN layers: Conv2D → BatchNorm → ReLU → Pool
# Classify using fully connected layers
logits = cnn_classifier(pileup_image) # (num_classes,)
probs = softmax(logits)
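The "extract context window" step can be sketched with padding and integer indexing. This is only an illustration: `extract_windows` is a hypothetical helper, it assumes an odd window size with zero padding at the reference edges, and it omits the image height dimension that the real classifier input carries.

```python
import jax.numpy as jnp


def extract_windows(pileup, window_size):
    """Return a (reference_length, window_size, 4) stack of centered windows."""
    half = window_size // 2
    # Zero-pad both ends so edge positions still get a full window.
    padded = jnp.pad(pileup, ((half, half), (0, 0)))
    # Row i of idx holds the padded indices covered by the window around position i.
    idx = jnp.arange(pileup.shape[0])[:, None] + jnp.arange(window_size)[None, :]
    return padded[idx]


pileup = jnp.arange(20.0).reshape(5, 4)
windows = extract_windows(pileup, window_size=3)
print(windows.shape)
```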
The CNN architecture:
- Input: Pileup images (reference_length, height, window_size, 4)
- Convolutional layers with increasing channels
- Batch normalization and ReLU activations
- Max pooling for spatial reduction
- Fully connected layers with dropout
Classes:
- 0: Reference (no variant)
- 1: SNP (single nucleotide polymorphism)
- 2: Indel (insertion/deletion)
Stage 4: Quality Recalibration (Optional)¤
VQSR-style quality score recalibration using a Gaussian Mixture Model:
# Extract variant features
features = [depth, max_prob, entropy, strand_balance]
# Apply GMM-based quality recalibration
quality_scores = gmm_quality_model(features)
filter_weights = sigmoid(quality_scores - threshold)
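To make the GMM step concrete, the sketch below scores feature vectors by their log-likelihood under a diagonal-covariance Gaussian mixture and turns the score into a soft filter weight. It is an assumption, not the library's model: `gmm_log_likelihood`, the diagonal covariance, and the example threshold of -3.0 are all chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def gmm_log_likelihood(features, weights, means, variances):
    """Log-likelihood of features (n, d) under a diagonal GMM with k components.

    weights: (k,), means: (k, d), variances: (k, d).
    """
    diff = features[:, None, :] - means[None, :, :]  # (n, k, d)
    # Per-component Gaussian log-density, summed over feature dimensions.
    log_comp = -0.5 * jnp.sum(
        diff**2 / variances + jnp.log(2.0 * jnp.pi * variances), axis=-1
    )
    # Mix the components in log space for numerical stability.
    return logsumexp(log_comp + jnp.log(weights), axis=-1)


features = jnp.array([[0.0], [3.0]])
log_lik = gmm_log_likelihood(
    features,
    weights=jnp.array([1.0]),
    means=jnp.array([[0.0]]),
    variances=jnp.array([[1.0]]),
)
# Hypothetical threshold: points far from the mixture get a low filter weight.
filter_weights = jax.nn.sigmoid(log_lik - (-3.0))
```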
Features used for recalibration:
- Depth: Total coverage at each position
- Max probability: Confidence of prediction
- Entropy: Uncertainty in predictions
- Strand balance: Proxy for strand bias
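Three of these features can be derived directly from read positions and class probabilities, as sketched below. The function name `recalibration_features` is hypothetical and the exact definitions used by the pipeline may differ; strand balance is omitted because the input format above carries no per-read strand information.

```python
import jax.numpy as jnp


def recalibration_features(positions, read_length, probabilities):
    """Stack depth, max probability, and entropy into a (reference_length, 3) array."""
    reference_length = probabilities.shape[0]
    # Every base of every read contributes one unit of coverage.
    coords = (positions[:, None] + jnp.arange(read_length)[None, :]).reshape(-1)
    depth = jnp.zeros(reference_length).at[coords].add(1.0)
    max_prob = probabilities.max(axis=-1)
    # Small epsilon guards against log(0) for fully confident predictions.
    entropy = -jnp.sum(probabilities * jnp.log(probabilities + 1e-9), axis=-1)
    return jnp.stack([depth, max_prob, entropy], axis=-1)


probabilities = jnp.full((3, 3), 1.0 / 3.0)
features = recalibration_features(jnp.array([0, 1]), 2, probabilities)
```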
Training¤
Using the Trainer Class¤
from diffbio.pipelines import create_enhanced_variant_calling_pipeline
from diffbio.utils.training import (
    Trainer, TrainingConfig, cross_entropy_loss,
    create_synthetic_training_data, data_iterator,
)
# Create pipeline
pipeline = create_enhanced_variant_calling_pipeline(reference_length=100)
# Create trainer
trainer = Trainer(
    pipeline,
    TrainingConfig(
        learning_rate=1e-3,
        num_epochs=50,
        log_every=10,
        grad_clip_norm=1.0,
    ),
)
# Generate training data
inputs, targets = create_synthetic_training_data(
    num_samples=100,
    reference_length=100,
)
# Define loss function
def loss_fn(predictions, targets):
    return cross_entropy_loss(
        predictions["logits"],
        targets["labels"],
        num_classes=3,
    )
# Train
trainer.train(
    data_iterator_fn=lambda: data_iterator(inputs, targets),
    loss_fn=loss_fn,
)
# Get trained pipeline
trained_pipeline = trainer.pipeline
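For readers who want to see what a per-position classification loss computes, here is a generic softmax cross-entropy in plain JAX. It is a sketch under the assumption of integer class labels and mean reduction; `softmax_cross_entropy` is a hypothetical name, and diffbio's `cross_entropy_loss` may differ in reduction or weighting.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy(logits, labels, num_classes):
    """Mean negative log-likelihood of integer labels under softmax(logits)."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    one_hot = jax.nn.one_hot(labels, num_classes)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


logits = jnp.zeros((5, 3))  # uniform predictions at 5 positions
labels = jnp.array([0, 1, 2, 0, 1])
loss = softmax_cross_entropy(logits, labels, 3)
# Gradients flow back to the logits, and from there into the pipeline.
grads = jax.grad(softmax_cross_entropy)(logits, labels, 3)
```

With uniform predictions over three classes the loss equals log(3), a useful sanity check for an untrained model.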
Train vs Eval Mode¤
# Enable dropout during training
pipeline.train_mode()
for batch in train_dataloader:
    loss = train_step(pipeline, batch)
# Disable dropout for inference
pipeline.eval_mode()
result, _, _ = pipeline.apply(test_data, {}, None)
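The reason the mode matters is that dropout behaves differently in the two settings. The sketch below shows standard inverted dropout in plain JAX; the `dropout` function and its `deterministic` flag are illustrative assumptions and not how the pipeline exposes this internally.

```python
import jax
import jax.numpy as jnp


def dropout(x, rate, key, deterministic):
    """Inverted dropout: random masking in training, identity in evaluation."""
    if deterministic:
        return x
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    # Rescale kept activations so the expected value matches evaluation mode.
    return jnp.where(keep, x / (1.0 - rate), 0.0)


x = jnp.ones(1000)
key = jax.random.PRNGKey(0)
train_out = dropout(x, rate=0.5, key=key, deterministic=False)
eval_out = dropout(x, rate=0.5, key=key, deterministic=True)
```

Leaving a model in training mode during inference would therefore make its predictions noisy and non-reproducible across random keys.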
Inference¤
Single Sample¤
pipeline.eval_mode()
result, _, _ = pipeline.apply(data, {}, None)
predictions = jnp.argmax(result["probabilities"], axis=-1)
# Find variant positions
variant_positions = jnp.where(predictions > 0)[0]
print(f"Variants at positions: {variant_positions}")
# Get quality-filtered variants (if recalibration enabled)
if "filter_weights" in result:
    high_quality = result["filter_weights"] > 0.5
    high_quality_variants = jnp.where((predictions > 0) & high_quality)[0]
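The filtering steps above can be bundled into a helper that returns readable call records. This is a sketch with hypothetical names (`summarize_calls`, `CLASS_NAMES`); it assumes the three-class ordering listed under Stage 3 and runs outside `jax.jit`, since it produces a Python list of variable length.

```python
import jax.numpy as jnp

# Assumed class order: 0 = reference, 1 = SNP, 2 = indel.
CLASS_NAMES = ("ref", "snp", "indel")


def summarize_calls(probabilities, filter_weights=None, min_weight=0.5):
    """Return (position, class_name, probability) for each non-reference call."""
    predictions = jnp.argmax(probabilities, axis=-1)
    keep = predictions > 0
    if filter_weights is not None:
        # Drop calls that the recalibration stage down-weights.
        keep = keep & (filter_weights > min_weight)
    calls = []
    for pos in jnp.where(keep)[0].tolist():
        cls = int(predictions[pos])
        calls.append((pos, CLASS_NAMES[cls], float(probabilities[pos, cls])))
    return calls


probabilities = jnp.array([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]])
filter_weights = jnp.array([0.9, 0.9, 0.1])
calls = summarize_calls(probabilities, filter_weights)
print(calls)
```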
JIT Compilation for Performance¤
@jax.jit
def fast_inference(pipeline, data):
    result, _, _ = pipeline.apply(data, {}, None)
    return result["probabilities"], result.get("quality_scores")
# Pre-compile
_ = fast_inference(pipeline, sample_data)
# Fast subsequent calls
preds, quality = fast_inference(pipeline, real_data)
Accessing Components¤
The pipeline exposes its sub-components for inspection and modification:
# Quality filter (if enabled)
if pipeline.quality_filter is not None:
    pipeline.quality_filter.threshold[...]  # Current threshold
# Pileup generator
pipeline.pileup.temperature[...]  # Pileup temperature
# CNN classifier
pipeline.cnn_classifier  # Neural network module
# Quality recalibration (if enabled)
if pipeline.quality_recalibration is not None:
    pipeline.quality_recalibration  # GMM-based quality model
Modifying Components¤
# Update quality threshold
if pipeline.quality_filter is not None:
    pipeline.quality_filter.threshold[...] = 25.0
# Access all learnable parameters
from flax import nnx
params = nnx.state(pipeline, nnx.Param)
print(jax.tree.map(lambda x: x.shape, params))
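The same tree utilities also give a quick parameter count. The example below uses a small hand-written dictionary as a stand-in for the state returned by `nnx.state`; the keys and shapes are hypothetical and do not reflect the pipeline's real parameter layout.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree standing in for nnx.state(pipeline, nnx.Param).
params = {
    "quality_filter": {"threshold": jnp.array(20.0)},
    "cnn": {"kernel": jnp.zeros((3, 3, 4, 64)), "bias": jnp.zeros((64,))},
}

# Shape of every leaf, mirroring the structure of the tree.
shapes = jax.tree.map(lambda x: x.shape, params)
# Total number of scalar parameters across all leaves.
num_params = sum(x.size for x in jax.tree.leaves(params))
print(shapes, num_params)
```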
Comparison with Basic Pipeline¤
| Feature | VariantCallingPipeline | EnhancedVariantCallingPipeline |
|---|---|---|
| Classifier | MLP | CNN (DeepVariant-style) |
| Quality Recalibration | No | Yes (VQSR-style) |
| Input Representation | Flat window | Pileup images |
| Dropout | No | Yes |
| Configurable Stages | No | Yes (enable_* flags) |
References¤
- Poplin, R. et al. (2018). "A universal SNP and small-indel variant caller using deep neural networks." Nature Biotechnology.
- Van der Auwera, G.A. & O'Connor, B.D. (2020). "Genomics in the Cloud: Using Docker, GATK, and WDL in Terra." - VQSR methodology.
- DePristo, M.A. et al. (2011). "A framework for variation discovery and genotyping using next-generation DNA sequencing data." Nature Genetics.