
Variant Calling Pipeline API

End-to-end differentiable variant calling pipeline.

VariantCallingPipeline

diffbio.pipelines.variant_calling.VariantCallingPipeline

VariantCallingPipeline(
    config: VariantCallingPipelineConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)

Bases: OperatorModule

End-to-end differentiable variant calling pipeline.

This pipeline processes sequencing reads to call variants:

Input data structure
  • reads: Float[Array, "num_reads read_length 4"] - One-hot encoded reads
  • positions: Int[Array, "num_reads"] - Read start positions on reference
  • quality: Float[Array, "num_reads read_length"] - Base quality scores

Output data structure (added keys)
  • pileup: Float[Array, "reference_length 4"] - Aggregated base frequencies
  • logits: Float[Array, "reference_length num_classes"] - Raw predictions
  • probabilities: Float[Array, "reference_length num_classes"] - Class probabilities

The pipeline is fully differentiable, supporting gradient-based training to optimize quality filtering, pileup aggregation, and classification jointly.
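To make the differentiability claim concrete, here is a minimal NumPy sketch of a soft quality filter feeding a quality-weighted pileup. It is not DiffBio's implementation: the sigmoid gate, its temperature, and the helper names soft_quality_weights and soft_pileup are hypothetical, chosen because a sigmoid keeps the threshold differentiable where a hard quality > threshold cut would not.

```python
import numpy as np


def soft_quality_weights(quality, threshold=20.0, temperature=1.0):
    """Hypothetical sigmoid gate: near 0 well below threshold, near 1 well above."""
    return 1.0 / (1.0 + np.exp(-(quality - threshold) / temperature))


def soft_pileup(reads, positions, quality, reference_length, threshold=20.0):
    """Accumulate quality-weighted base vectors at each reference position.

    reads: (num_reads, read_length, 4); positions: (num_reads,) start offsets;
    quality: (num_reads, read_length). Returns (reference_length, 4).
    """
    read_length = reads.shape[1]
    weights = soft_quality_weights(quality, threshold)             # (num_reads, read_length)
    weighted = reads * weights[..., None]                          # (num_reads, read_length, 4)
    # Reference coordinate covered by each base of each read.
    coords = positions[:, None] + np.arange(read_length)[None, :]  # (num_reads, read_length)
    valid = (coords >= 0) & (coords < reference_length)            # drop bases past the reference
    pileup = np.zeros((reference_length, 4))
    # Unbuffered scatter-add, so overlapping reads accumulate instead of overwriting.
    np.add.at(pileup, coords[valid], weighted[valid])
    return pileup


# Two all-"A" reads of length 3 starting at positions 0 and 1, all at Q40.
reads = np.zeros((2, 3, 4))
reads[:, :, 0] = 1.0
pileup = soft_pileup(reads, np.array([0, 1]), np.full((2, 3), 40.0), reference_length=5)
print(pileup[:, 0])  # approximately [1, 2, 2, 1, 0]
```

In JAX the same scatter-add is usually written jnp.zeros((reference_length, 4)).at[idx].add(vals); because every step is smooth, jax.grad can propagate a downstream loss back to the threshold.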

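Example
from flax import nnx
from diffbio.pipelines import VariantCallingPipeline, VariantCallingPipelineConfig

config = VariantCallingPipelineConfig(reference_length=100)
pipeline = VariantCallingPipeline(config, rngs=nnx.Rngs(42))
pipeline.eval_mode()  # Disable dropout for inference
# Process a batch of samples; input_batch is a Batch whose elements hold
# "reads", "positions", and "quality"
result_batch = pipeline(input_batch)
probs = result_batch.data.get_value()["probabilities"]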

Parameters:
  • config (VariantCallingPipelineConfig, required): Pipeline configuration
  • rngs (Rngs, required): Random number generators for parameter initialization
  • name (str | None, default None): Optional name for the pipeline

apply

apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]

Apply the full variant calling pipeline to a single sample.

Parameters:
  • data (dict[str, Array], required): Input data containing:
      - reads: Float[Array, "num_reads read_length 4"]
      - positions: Int[Array, "num_reads"]
      - quality: Float[Array, "num_reads read_length"]
  • state (dict[str, Any], required): Element state (passed through)
  • metadata (dict[str, Any] | None, required): Element metadata (passed through)
  • random_params (Any, default None): Not used (deterministic pipeline)
  • stats (dict[str, Any] | None, default None): Optional statistics dict

Returns:
  • tuple[dict[str, Array], dict[str, Any], dict[str, Any] | None]: Tuple of (output_data, state, metadata), where output_data contains all input keys plus pileup, logits, and probabilities.

set_training

set_training(training: bool = True) -> None

Set pipeline training mode.

Parameters:
  • training (bool, default True): If True, enable dropout. If False, disable dropout.

train_mode

train_mode() -> None

Set pipeline to training mode (enables dropout).

eval_mode

eval_mode() -> None

Set pipeline to evaluation mode (disables dropout).
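To illustrate what train_mode and eval_mode toggle, here is a minimal NumPy sketch of inverted dropout, the scaling scheme Flax's dropout layer also uses. The function name dropout is hypothetical, and the dropout rate of DiffBio's classifier is not documented on this page, so rate=0.5 below is an arbitrary choice.

```python
import numpy as np


def dropout(x, rate, training, rng):
    """Inverted dropout: in training, zero each unit with probability `rate`
    and rescale survivors by 1 / (1 - rate); in eval, return x unchanged."""
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)


rng = np.random.default_rng(0)
hidden = np.ones(8)
print(dropout(hidden, rate=0.5, training=False, rng=rng))  # identical to hidden
print(dropout(hidden, rate=0.5, training=True, rng=rng))   # mix of 0.0 and 2.0
```

The rescaling keeps the expected activation the same in both modes, which is why eval mode can simply skip the mask.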

call_variants

call_variants(
    batch: Batch, threshold: float = 0.5
) -> dict[str, Array]

Convenience method to call variants from a batch.

Parameters:
  • batch (Batch, required): Input batch with reads, positions, and quality
  • threshold (float, default 0.5): Probability threshold for variant calling

Returns:
  • dict[str, Array]: Dict containing:
      - predictions: Int[Array, "batch reference_length"] - Predicted classes
      - probabilities: Float[Array, "batch reference_length num_classes"] - Class probabilities
      - variant_positions: List of (batch_idx, position) tuples
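The page does not say exactly which probability the threshold is compared against. The NumPy sketch below assumes class 0 is the reference class (following the ref/snp/indel order in the num_classes description) and calls a variant wherever 1 - P(ref) exceeds the threshold; call_variants_sketch is a hypothetical name, not the library method.

```python
import numpy as np


def call_variants_sketch(probabilities, threshold=0.5, ref_class=0):
    """probabilities: (batch, reference_length, num_classes).

    Assumes class `ref_class` is the reference (no-variant) class.
    """
    predictions = np.argmax(probabilities, axis=-1)        # (batch, reference_length)
    variant_prob = 1.0 - probabilities[..., ref_class]     # P(any non-reference class)
    batch_idx, pos = np.nonzero(variant_prob > threshold)  # row-major order
    variant_positions = [(int(b), int(p)) for b, p in zip(batch_idx, pos)]
    return {
        "predictions": predictions,
        "probabilities": probabilities,
        "variant_positions": variant_positions,
    }


probs = np.array([
    [[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.6, 0.1, 0.3]],
    [[0.1, 0.1, 0.8], [0.95, 0.03, 0.02], [0.5, 0.25, 0.25]],
])
out = call_variants_sketch(probs, threshold=0.5)
print(out["variant_positions"])  # [(0, 1), (1, 0)]
```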

VariantCallingPipelineConfig

diffbio.pipelines.variant_calling.VariantCallingPipelineConfig (dataclass)

VariantCallingPipelineConfig(
    reference_length: int = 100,
    num_classes: int = 3,
    quality_threshold: float = 20.0,
    pileup_window_size: int = 11,
    classifier_hidden_dim: int = 64,
    use_quality_weights: bool = True,
    classifier_type: str = MLP,
    cnn_hidden_channels: tuple[int, ...] = (32, 64),
    cnn_fc_dims: tuple[int, ...] = (64, 32),
    apply_pileup_softmax: bool = True,
)

Bases: OperatorConfig

Configuration for the variant calling pipeline.

Attributes:
  • reference_length (int): Length of the reference sequence
  • num_classes (int): Number of variant classes (default: 3 for ref/SNP/indel)
  • quality_threshold (float): Initial quality score threshold for filtering
  • pileup_window_size (int): Window size for pileup context
  • classifier_hidden_dim (int): Hidden dimension for the classifier MLP
  • use_quality_weights (bool): Whether to weight the pileup by quality scores
  • classifier_type (str): Type of classifier (ClassifierType.MLP or ClassifierType.CNN)
  • cnn_hidden_channels (tuple[int, ...]): Hidden channels for the CNN classifier
  • cnn_fc_dims (tuple[int, ...]): Fully connected layer dimensions for the CNN
  • apply_pileup_softmax (bool): Whether to apply softmax to the pileup output
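This page does not describe how pileup_window_size is consumed. One plausible reading is that each position is classified from a centered window of neighbouring pileup rows; the NumPy sketch below assumes exactly that, with zero padding at the reference ends. The name pileup_windows and the odd-size requirement are choices of this sketch (both documented window sizes, 11 and 21, are odd).

```python
import numpy as np


def pileup_windows(pileup, window_size):
    """Stack a centered window of pileup rows around each position.

    pileup: (reference_length, 4) -> (reference_length, window_size, 4).
    """
    if window_size % 2 == 0:
        raise ValueError("this sketch requires an odd window_size so windows are centered")
    half = window_size // 2
    padded = np.pad(pileup, ((half, half), (0, 0)))  # zero rows beyond both reference ends
    # Row i of the result covers original positions i - half .. i + half.
    idx = np.arange(pileup.shape[0])[:, None] + np.arange(window_size)[None, :]
    return padded[idx]


pileup = np.arange(20.0).reshape(5, 4)
windows = pileup_windows(pileup, window_size=3)  # (5, 3, 4)
features = windows.reshape(5, -1)                # (5, 12) flat per-position features
```

Flattening each window, as in the last line, gives the kind of fixed-length per-position vector an MLP classifier could take; a CNN could instead convolve along the window axis.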

Factory Function

create_variant_calling_pipeline

diffbio.pipelines.variant_calling.create_variant_calling_pipeline

create_variant_calling_pipeline(
    reference_length: int = 100,
    num_classes: int = 3,
    quality_threshold: float = 20.0,
    hidden_dim: int = 64,
    classifier_type: str = MLP,
    pileup_window_size: int = 11,
    apply_pileup_softmax: bool = True,
    seed: int = 42,
) -> VariantCallingPipeline

Factory function to create a variant calling pipeline.

Parameters:
  • reference_length (int, default 100): Length of the reference sequence
  • num_classes (int, default 3): Number of variant classes
  • quality_threshold (float, default 20.0): Quality score threshold
  • hidden_dim (int, default 64): Hidden dimension for the classifier
  • classifier_type (str, default MLP): Type of classifier (ClassifierType.MLP or ClassifierType.CNN)
  • pileup_window_size (int, default 11): Window size for pileup context
  • apply_pileup_softmax (bool, default True): Whether to apply softmax to the pileup (False is better for variant detection because it preserves the raw coverage-weighted signals)
  • seed (int, default 42): Random seed

Returns:
  • VariantCallingPipeline: Configured VariantCallingPipeline instance
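The note on apply_pileup_softmax can be seen with two positions that share the same 50/50 A/G split but differ in depth. In the raw pileup the row sum still equals the coverage; after a per-position softmax every row sums to 1, so depth is no longer available as a row sum. This is a NumPy illustration of the general effect only; whether DiffBio applies a temperature before its softmax is not covered here.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


# Columns are A, C, G, T: a shallow (depth 2) and a deep (depth 30) position.
raw = np.array([[1.0, 0.0, 1.0, 0.0],
                [15.0, 0.0, 15.0, 0.0]])
depth = raw.sum(axis=-1)            # coverage is recoverable from the raw pileup
normalized = softmax(raw, axis=-1)  # every row now sums to 1, regardless of depth
```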

Usage Examples

Quick Start

from diffbio.pipelines import create_variant_calling_pipeline
import jax
import jax.numpy as jnp

# Create pipeline
pipeline = create_variant_calling_pipeline(
    reference_length=100,
    num_classes=3,
)

# Prepare data
data = {
    "reads": jax.nn.softmax(
        jax.random.uniform(jax.random.PRNGKey(0), (20, 30, 4)),
        axis=-1
    ),
    "positions": jax.random.randint(jax.random.PRNGKey(1), (20,), 0, 70),
    "quality": jax.random.uniform(jax.random.PRNGKey(2), (20, 30), minval=10, maxval=40),
}

# Run pipeline
result, _, _ = pipeline.apply(data, {}, None)
predictions = jnp.argmax(result["probabilities"], axis=-1)

Full Configuration

from diffbio.pipelines import VariantCallingPipeline, VariantCallingPipelineConfig
from flax import nnx

config = VariantCallingPipelineConfig(
    reference_length=10000,
    num_classes=3,
    quality_threshold=20.0,
    pileup_window_size=21,
    classifier_hidden_dim=128,
    use_quality_weights=True,
)

pipeline = VariantCallingPipeline(config, rngs=nnx.Rngs(42))
pipeline.eval_mode()

Training Mode

# Enable dropout
pipeline.train_mode()

# Training loop (dataloader and train_step are user-supplied, not part of this API)
for batch in dataloader:
    loss = train_step(pipeline, batch)

# Disable dropout for inference
pipeline.eval_mode()
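The loop above leaves train_step to the user. As one hedged illustration, a natural objective for the logits output is per-position cross-entropy against integer labels (0 = ref, 1 = SNP, 2 = indel, following the order in the num_classes description). This is an assumption, not DiffBio's documented loss, and per_position_cross_entropy is a hypothetical helper; in a real train_step it would be written with jax.numpy and differentiated with jax.grad.

```python
import numpy as np


def per_position_cross_entropy(logits, labels):
    """Mean cross-entropy over reference positions.

    logits: (reference_length, num_classes); labels: (reference_length,) int class ids.
    """
    # Numerically stable log-softmax: subtract each row's max before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the true class at each position.
    picked = np.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    return -picked.mean()


labels = np.array([0, 1, 2, 0])   # ref, SNP, indel, ref
uniform_logits = np.zeros((4, 3))  # no preference between classes
loss = per_position_cross_entropy(uniform_logits, labels)  # log(3) up to rounding
```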

Access Components

# Quality filter threshold
pipeline.quality_filter.threshold[...]

# Pileup temperature
pipeline.pileup.config.temperature

# Classifier network
pipeline.classifier

Input Specifications

  • reads: (num_reads, read_length, 4) - One-hot encoded reads
  • positions: (num_reads,) - Read start positions
  • quality: (num_reads, read_length) - Phred quality scores
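Real reads usually arrive as strings with FASTQ quality lines. The NumPy sketch below builds the reads and quality arrays from them. The A, C, G, T channel order is an assumption (check DiffBio's convention before relying on it), unknown bases such as N are left all-zero as a choice of this sketch, and quality decoding uses the standard Phred+33 offset. Both helper names are hypothetical.

```python
import numpy as np

BASES = "ACGT"  # assumed channel order for the last axis of "reads"


def one_hot_reads(reads: list[str]) -> np.ndarray:
    """Encode equal-length DNA strings as (num_reads, read_length, 4)."""
    index = {b: i for i, b in enumerate(BASES)}
    lengths = {len(r) for r in reads}
    if len(lengths) != 1:
        raise ValueError("all reads must have the same length")
    out = np.zeros((len(reads), lengths.pop(), 4), dtype=np.float32)
    for i, read in enumerate(reads):
        for j, base in enumerate(read.upper()):
            if base in index:  # unknown bases such as N stay all-zero
                out[i, j, index[base]] = 1.0
    return out


def phred33_to_scores(quals: list[str]) -> np.ndarray:
    """Decode FASTQ quality strings (Phred+33) to (num_reads, read_length) scores."""
    return np.array([[ord(c) - 33 for c in q] for q in quals], dtype=np.float32)


reads = one_hot_reads(["ACGT", "NNTA"])        # (2, 4, 4)
quality = phred33_to_scores(["IIII", "##II"])  # "I" -> 40, "#" -> 2
```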

Output Specifications

  • reads: (num_reads, read_length, 4) - Original reads
  • positions: (num_reads,) - Original positions
  • quality: (num_reads, read_length) - Original quality
  • filtered_reads: (num_reads, read_length, 4) - Quality-filtered reads
  • filtered_quality: (num_reads, read_length) - Filtered quality scores
  • pileup: (reference_length, 4) - Aggregated pileup
  • logits: (reference_length, num_classes) - Raw predictions
  • probabilities: (reference_length, num_classes) - Class probabilities
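Assuming the usual relationship between the last two outputs (probabilities as a softmax of logits over the class axis, which this page implies but does not state), the NumPy sketch below shows how they relate. Because softmax is monotonic, taking the argmax of either array gives the same predicted class per position, as in the Quick Start.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


reference_length, num_classes = 6, 3
logits = np.random.default_rng(0).normal(size=(reference_length, num_classes))
probabilities = softmax(logits)              # (reference_length, num_classes), rows sum to 1
predictions = probabilities.argmax(axis=-1)  # same as logits.argmax(axis=-1)
```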