Variant Calling Pipeline API
End-to-end differentiable variant calling pipeline.
VariantCallingPipeline
diffbio.pipelines.variant_calling.VariantCallingPipeline
VariantCallingPipeline(
    config: VariantCallingPipelineConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)
Bases: OperatorModule
End-to-end differentiable variant calling pipeline.
This pipeline processes sequencing reads to call variants in three stages: quality filtering, pileup aggregation, and classification.
Input data structure:
- reads: Float[Array, "num_reads read_length 4"] - One-hot encoded reads
- positions: Int[Array, "num_reads"] - Read start positions on reference
- quality: Float[Array, "num_reads read_length"] - Base quality scores
Output data structure (adds):
- pileup: Float[Array, "reference_length 4"] - Aggregated base frequencies
- logits: Float[Array, "reference_length num_classes"] - Raw predictions
- probabilities: Float[Array, "reference_length num_classes"] - Class probabilities
The pipeline is fully differentiable, supporting gradient-based training to optimize quality filtering, pileup aggregation, and classification jointly.
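To make the data flow concrete, here is a minimal sketch of how quality-weighted reads could be aggregated into a pileup of the documented shape. It is not diffbio's implementation: the function name `soft_pileup`, the sigmoid weighting, and the normalization step are assumptions chosen for illustration. NumPy is used so the sketch runs without JAX; a JAX version would replace `np.add.at` with the `.at[...].add(...)` update.

```python
import numpy as np


def soft_pileup(reads, positions, quality, reference_length, threshold=20.0):
    """Toy pileup: sigmoid quality weights, scatter-add, per-position normalization.

    Hypothetical helper for illustration; assumes all positions are >= 0.
    """
    num_reads, read_length, _ = reads.shape
    # Smooth stand-in for a hard cutoff: near 1 well above threshold, near 0 well below.
    weights = 1.0 / (1.0 + np.exp(-(quality - threshold)))
    # Reference index covered by every base of every read.
    ref_idx = positions[:, None] + np.arange(read_length)[None, :]
    in_bounds = ref_idx < reference_length
    pileup = np.zeros((reference_length, 4))
    # Accumulate weighted base evidence at each reference position.
    np.add.at(pileup, ref_idx[in_bounds], (reads * weights[..., None])[in_bounds])
    coverage = pileup.sum(axis=-1, keepdims=True)
    # Normalize to base frequencies; uncovered positions stay all-zero.
    return pileup / np.maximum(coverage, 1e-8)


rng = np.random.default_rng(0)
reads = np.eye(4)[rng.integers(0, 4, size=(20, 30))]  # (20, 30, 4) one-hot
positions = rng.integers(0, 70, size=20)
quality = rng.uniform(10, 40, size=(20, 30))
pileup = soft_pileup(reads, positions, quality, reference_length=100)
print(pileup.shape)  # (100, 4)
```

Because the weight is a smooth function of the threshold, a gradient with respect to the threshold exists, which is the property that allows the filtering stage to be trained jointly with the later stages.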
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | VariantCallingPipelineConfig | Pipeline configuration | required |
| rngs | Rngs | Random number generators for parameter initialization | required |
| name | str \| None | Optional name for the pipeline | None |
apply
apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]
Apply the full variant calling pipeline to a single sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Array] | Input data containing reads: Float[Array, "num_reads read_length 4"]; positions: Int[Array, "num_reads"]; quality: Float[Array, "num_reads read_length"] | required |
| state | dict[str, Any] | Element state (passed through) | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through) | required |
| random_params | Any | Not used (deterministic pipeline) | None |
| stats | dict[str, Any] \| None | Optional statistics dict | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Array], dict[str, Any], dict[str, Any] \| None] | Tuple of (output_data, state, metadata), where output_data contains all input keys plus pileup, logits, and probabilities. |
set_training
Set pipeline training mode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| training | bool | If True, enable dropout. If False, disable dropout. | True |
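The effect of the training flag can be illustrated with a small sketch of inverted dropout. The source does not say which layers use dropout or at what rate, so the function name, the `rate` argument, and the rescaling convention below are assumptions rather than the pipeline's actual behaviour.

```python
import numpy as np


def dropout(x, rate, training, rng):
    """Inverted dropout: a no-op unless training is True (illustrative only)."""
    if not training or rate == 0.0:
        return x
    # Keep each unit with probability (1 - rate) and rescale the survivors
    # so the expected activation is unchanged.
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)


rng = np.random.default_rng(0)
x = np.ones((4, 8))
eval_out = dropout(x, rate=0.5, training=False, rng=rng)   # identical to x
train_out = dropout(x, rate=0.5, training=True, rng=rng)   # entries are 0.0 or 2.0
```

With the flag off, outputs are deterministic, which is why inference is normally run in that mode.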
call_variants
Convenience method to call variants from a batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch with reads, positions, quality | required |
| threshold | float | Probability threshold for variant calling | 0.5 |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dict containing predictions: Int[Array, "batch reference_length"] (predicted classes); probabilities: Float[Array, "batch reference_length num_classes"]; variant_positions: list of (batch_idx, position) tuples |
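The returned structure can be sketched from class probabilities alone. How `call_variants` applies the threshold is not stated in the source; the sketch below assumes that class 0 is the reference class and that a position is called when the combined non-reference probability exceeds the threshold. The function name `call_from_probabilities` is hypothetical.

```python
import numpy as np


def call_from_probabilities(probabilities, threshold=0.5):
    """Build a call_variants-style dict from (batch, reference_length, num_classes) probabilities."""
    predictions = np.argmax(probabilities, axis=-1)
    # ASSUMPTION: class 0 means "reference", so 1 - p(ref) is the variant probability.
    variant_prob = 1.0 - probabilities[..., 0]
    is_variant = variant_prob > threshold
    variant_positions = [(int(b), int(p)) for b, p in np.argwhere(is_variant)]
    return {
        "predictions": predictions,
        "probabilities": probabilities,
        "variant_positions": variant_positions,
    }


probs = np.array([[[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.4, 0.3, 0.3]]])
calls = call_from_probabilities(probs, threshold=0.5)
print(calls["variant_positions"])  # [(0, 1), (0, 2)]
```

Under this rule a position can be reported as a variant even when the argmax class is the reference (position 2 above), so raising the threshold trades sensitivity for precision.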
VariantCallingPipelineConfig
diffbio.pipelines.variant_calling.VariantCallingPipelineConfig
dataclass
VariantCallingPipelineConfig(
    reference_length: int = 100,
    num_classes: int = 3,
    quality_threshold: float = 20.0,
    pileup_window_size: int = 11,
    classifier_hidden_dim: int = 64,
    use_quality_weights: bool = True,
    classifier_type: str = MLP,
    cnn_hidden_channels: tuple[int, ...] = (32, 64),
    cnn_fc_dims: tuple[int, ...] = (64, 32),
    apply_pileup_softmax: bool = True,
)
Bases: OperatorConfig
Configuration for the variant calling pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| reference_length | int | Length of reference sequence |
| num_classes | int | Number of variant classes (default: 3 for ref/snp/indel) |
| quality_threshold | float | Initial quality score threshold for filtering |
| pileup_window_size | int | Window size for pileup context |
| classifier_hidden_dim | int | Hidden dimension for classifier MLP |
| use_quality_weights | bool | Whether to weight pileup by quality scores |
| classifier_type | str | Type of classifier (ClassifierType.MLP or ClassifierType.CNN) |
| cnn_hidden_channels | tuple[int, ...] | Hidden channels for CNN classifier |
| cnn_fc_dims | tuple[int, ...] | Fully connected layer dimensions for CNN |
| apply_pileup_softmax | bool | Whether to apply softmax to pileup output |
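For experimenting with the documented fields without installing the library, a stand-in dataclass can mirror the names and defaults listed above. `PipelineConfigSketch` and its validation rules are hypothetical: the source does not say what, if anything, the real config validates, and the string used for `classifier_type` is a placeholder for what is probably a `ClassifierType` member.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class PipelineConfigSketch:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    reference_length: int = 100
    num_classes: int = 3
    quality_threshold: float = 20.0
    pileup_window_size: int = 11
    classifier_hidden_dim: int = 64
    use_quality_weights: bool = True
    classifier_type: str = "MLP"  # assumption: the real field likely holds a ClassifierType member
    cnn_hidden_channels: tuple[int, ...] = (32, 64)
    cnn_fc_dims: tuple[int, ...] = (64, 32)
    apply_pileup_softmax: bool = True

    def __post_init__(self):
        # Illustrative checks only; not taken from the library.
        if self.reference_length <= 0:
            raise ValueError("reference_length must be positive")
        if self.num_classes < 2:
            raise ValueError("num_classes must be at least 2")
        if self.pileup_window_size % 2 == 0:
            raise ValueError("pileup_window_size should be odd so the window is centered")


base = PipelineConfigSketch()
# Derive a variant without mutating the original (the dataclass is frozen).
wide = replace(base, reference_length=10000, pileup_window_size=21)
```

Keeping the config immutable and deriving variants with `dataclasses.replace` makes it easy to compare runs that differ in a single field.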
Factory Function
create_variant_calling_pipeline
diffbio.pipelines.variant_calling.create_variant_calling_pipeline
create_variant_calling_pipeline(
    reference_length: int = 100,
    num_classes: int = 3,
    quality_threshold: float = 20.0,
    hidden_dim: int = 64,
    classifier_type: str = MLP,
    pileup_window_size: int = 11,
    apply_pileup_softmax: bool = True,
    seed: int = 42,
) -> VariantCallingPipeline
Factory function to create a variant calling pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reference_length | int | Length of reference sequence | 100 |
| num_classes | int | Number of variant classes | 3 |
| quality_threshold | float | Quality score threshold | 20.0 |
| hidden_dim | int | Hidden dimension for classifier | 64 |
| classifier_type | str | Type of classifier (ClassifierType.MLP or ClassifierType.CNN) | MLP |
| pileup_window_size | int | Window size for pileup context | 11 |
| apply_pileup_softmax | bool | Whether to apply softmax to pileup (False is better for variant detection as it preserves raw coverage-weighted signals) | True |
| seed | int | Random seed | 42 |
Returns:
| Type | Description |
|---|---|
| VariantCallingPipeline | Configured VariantCallingPipeline instance |
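The `pileup_window_size` parameter implies that each reference position is classified using neighbouring pileup columns. One common way to build such context windows is sketched below; whether the library pads with zeros, centers the window, or uses a different layout is not stated in the source, so `pileup_windows` is a hypothetical helper.

```python
import numpy as np


def pileup_windows(pileup, window_size):
    """Return (reference_length, window_size, 4) context windows, zero-padded at the edges.

    Assumes an odd window_size so that index window_size // 2 is the center.
    """
    half = window_size // 2
    padded = np.pad(pileup, ((half, half), (0, 0)))
    # Row i of idx holds the padded indices of the window centered on position i.
    idx = np.arange(pileup.shape[0])[:, None] + np.arange(window_size)[None, :]
    return padded[idx]


toy_pileup = np.arange(40, dtype=float).reshape(10, 4)
windows = pileup_windows(toy_pileup, window_size=5)
print(windows.shape)  # (10, 5, 4)
```

The center slice `windows[:, window_size // 2]` reproduces the original pileup, and the first and last windows contain zero rows where the context runs off the reference.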
Usage Examples
Quick Start
from diffbio.pipelines import create_variant_calling_pipeline
import jax
import jax.numpy as jnp
# Create pipeline
pipeline = create_variant_calling_pipeline(
    reference_length=100,
    num_classes=3,
)
# Prepare data
data = {
    "reads": jax.nn.softmax(
        jax.random.uniform(jax.random.PRNGKey(0), (20, 30, 4)),
        axis=-1,
    ),
    "positions": jax.random.randint(jax.random.PRNGKey(1), (20,), 0, 70),
    "quality": jax.random.uniform(jax.random.PRNGKey(2), (20, 30), minval=10, maxval=40),
}
# Run pipeline
result, _, _ = pipeline.apply(data, {}, None)
predictions = jnp.argmax(result["probabilities"], axis=-1)
Full Configuration
from diffbio.pipelines import VariantCallingPipeline, VariantCallingPipelineConfig
from flax import nnx
config = VariantCallingPipelineConfig(
    reference_length=10000,
    num_classes=3,
    quality_threshold=20.0,
    pileup_window_size=21,
    classifier_hidden_dim=128,
    use_quality_weights=True,
)
pipeline = VariantCallingPipeline(config, rngs=nnx.Rngs(42))
pipeline.eval_mode()
Training Mode
# Enable dropout
pipeline.train_mode()
# Training loop (dataloader and train_step are user-supplied placeholders)
for batch in dataloader:
    loss = train_step(pipeline, batch)
# Disable dropout for inference
pipeline.eval_mode()
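The loop above leaves `train_step` undefined. As one possible ingredient, the sketch below computes a per-position cross-entropy loss from logits of the documented shape. The function name and the choice of cross-entropy are assumptions; the source does not specify a loss. NumPy is used so the snippet runs without JAX; in a real training step the same expression written with `jax.numpy` could be differentiated with `jax.grad`.

```python
import numpy as np


def cross_entropy_from_logits(logits, labels):
    """Mean cross-entropy for logits of shape (reference_length, num_classes).

    labels holds one integer class id per reference position.
    """
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class at every position.
    return -log_probs[np.arange(labels.shape[0]), labels].mean()


uniform_logits = np.zeros((5, 3))
labels = np.array([0, 1, 2, 0, 1])
loss = cross_entropy_from_logits(uniform_logits, labels)  # equals log(3) for uniform logits
```

Uniform logits give a loss of log(num_classes), which is a convenient sanity check for an untrained classifier.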
Access Components
# Quality filter threshold
pipeline.quality_filter.threshold[...]
# Pileup temperature
pipeline.pileup.config.temperature
# Classifier network
pipeline.classifier
Input Specifications
| Key | Shape | Description |
|---|---|---|
| reads | (num_reads, read_length, 4) | One-hot encoded reads |
| positions | (num_reads,) | Read start positions |
| quality | (num_reads, read_length) | Phred quality scores |
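Inputs in this layout can be prepared from FASTQ-style strings with two small helpers. The channel order A, C, G, T and the all-zero encoding for ambiguous bases are assumptions, since the source does not state the order the pipeline expects; the Phred+33 offset is the standard FASTQ quality encoding. Both function names are hypothetical.

```python
import numpy as np

BASES = "ACGT"  # assumed channel order; verify against the pipeline before use


def one_hot_encode(sequence):
    """Encode a DNA string as (read_length, 4); non-ACGT bases become all-zero rows."""
    encoded = np.zeros((len(sequence), 4), dtype=np.float32)
    for i, base in enumerate(sequence.upper()):
        j = BASES.find(base)
        if j >= 0:
            encoded[i, j] = 1.0
    return encoded


def decode_phred33(quality_string):
    """Convert a Phred+33 ASCII quality string to numeric scores."""
    return np.array([ord(c) - 33 for c in quality_string], dtype=np.float32)


sequences = ["ACGTN", "ttgca"]
qualities = ["I!5II", "55555"]
reads = np.stack([one_hot_encode(s) for s in sequences])      # (2, 5, 4)
quality = np.stack([decode_phred33(q) for q in qualities])    # (2, 5)
positions = np.array([0, 3], dtype=np.int32)                  # (2,)
```

Stacking requires equal-length reads, so variable-length reads would need padding or trimming first.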
Output Specifications
| Key | Shape | Description |
|---|---|---|
| reads | (num_reads, read_length, 4) | Original reads |
| positions | (num_reads,) | Original positions |
| quality | (num_reads, read_length) | Original quality |
| filtered_reads | (num_reads, read_length, 4) | Quality-filtered reads |
| filtered_quality | (num_reads, read_length) | Filtered quality scores |
| pileup | (reference_length, 4) | Aggregated pileup |
| logits | (reference_length, num_classes) | Raw predictions |
| probabilities | (reference_length, num_classes) | Class probabilities |
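A lightweight check of the result dict against the shapes in this table can catch configuration mismatches early. The helper below is a sketch, not part of the library: `check_outputs` is a hypothetical name, and the test that probabilities sum to 1 along the class axis is an assumption about how they are normalized.

```python
import numpy as np


def check_outputs(result, reference_length, num_classes):
    """Raise if pileup, logits, or probabilities deviate from the documented shapes."""
    expected = {
        "pileup": (reference_length, 4),
        "logits": (reference_length, num_classes),
        "probabilities": (reference_length, num_classes),
    }
    for key, shape in expected.items():
        if key not in result:
            raise KeyError(f"missing output key: {key}")
        actual = tuple(result[key].shape)
        if actual != shape:
            raise ValueError(f"{key}: expected shape {shape}, got {actual}")
    # ASSUMPTION: probabilities are normalized over the class axis.
    row_sums = np.asarray(result["probabilities"]).sum(axis=-1)
    if not np.allclose(row_sums, 1.0, atol=1e-5):
        raise ValueError("probabilities do not sum to 1 along the class axis")


example = {
    "pileup": np.zeros((100, 4)),
    "logits": np.zeros((100, 3)),
    "probabilities": np.full((100, 3), 1.0 / 3.0),
}
check_outputs(example, reference_length=100, num_classes=3)  # passes silently
```

The same function accepts JAX arrays, since it only relies on `.shape` and on conversion through `np.asarray`.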