
Enhanced Variant Calling Pipeline API

DeepVariant-style, end-to-end differentiable variant calling pipeline with a CNN classifier and quality recalibration.

EnhancedVariantCallingPipeline

diffbio.pipelines.enhanced_variant_calling.EnhancedVariantCallingPipeline

EnhancedVariantCallingPipeline(
    config: EnhancedVariantCallingPipelineConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)

Bases: OperatorModule

Enhanced end-to-end differentiable variant calling pipeline.

This pipeline calls variants from sequencing reads in four stages: optional quality filtering, pileup generation, a DeepVariant-style CNN classifier, and optional VQSR-style quality recalibration.

Input data structure
  • reads: Float[Array, "num_reads read_length 4"] - One-hot encoded reads
  • positions: Int[Array, "num_reads"] - Read start positions on reference
  • quality: Float[Array, "num_reads read_length"] - Base quality scores

Output data structure (keys added to the input):
  • pileup: Float[Array, "reference_length 4"] - Aggregated base frequencies
  • logits: Float[Array, "reference_length num_classes"] - Raw predictions
  • probabilities: Float[Array, "reference_length num_classes"] - Class probabilities
  • quality_scores: Float[Array, "reference_length"] - Recalibrated quality (only when enable_quality_recalibration=True)
  • filter_weights: Float[Array, "reference_length"] - Soft filter weights (only when enable_quality_recalibration=True)

The pipeline is fully differentiable, supporting gradient-based training to optimize all components jointly.
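The pileup stage condenses per-read evidence into per-position base frequencies. The following is a rough NumPy sketch of what such an aggregation can look like, not diffbio's implementation: each base is weighted by its Phred-derived probability of being correct, 1 - 10^(-Q/10), scatter-added at its reference coordinate, and normalised per position. The weighting scheme and the `soft_pileup` name are assumptions. A jax.numpy version would use `counts.at[idx].add(...)` in place of `np.add.at` and `jnp.where` for the division, keeping it differentiable.

```python
import numpy as np

def soft_pileup(reads, positions, quality, reference_length):
    """Hypothetical quality-weighted pileup: (num_reads, L, 4) -> (reference_length, 4)."""
    num_reads, read_length, _ = reads.shape
    # Phred score Q -> probability the base call is correct: 1 - 10^(-Q/10)
    confidence = 1.0 - 10.0 ** (-quality / 10.0)                    # (num_reads, L)
    # Reference coordinate covered by each base of each read
    coords = positions[:, None] + np.arange(read_length)[None, :]   # (num_reads, L)
    valid = (coords >= 0) & (coords < reference_length)
    counts = np.zeros((reference_length, 4))
    # Scatter-add confidence-weighted base vectors; np.add.at accumulates repeated indices
    np.add.at(counts, coords[valid], reads[valid] * confidence[valid][:, None])
    # Normalise to per-position frequencies; uncovered positions stay all-zero
    totals = counts.sum(axis=-1, keepdims=True)
    return np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)

# Toy check: two reads of length 2 on a 4-bp reference (base indices 0-3)
reads = np.eye(4)[np.array([[0, 1], [0, 2]])]   # (2, 2, 4) one-hot
positions = np.array([0, 1])
quality = np.full((2, 2), 30.0)
pileup = soft_pileup(reads, positions, quality, reference_length=4)
```

Position 1 is covered by both reads (bases 1 and 0), so its row splits evenly between them, while position 3 is uncovered and stays zero.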

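Example
from diffbio.pipelines import EnhancedVariantCallingPipeline, EnhancedVariantCallingPipelineConfig
from flax import nnx

config = EnhancedVariantCallingPipelineConfig(reference_length=1000)
pipeline = EnhancedVariantCallingPipeline(config, rngs=nnx.Rngs(42))
# data: dict with "reads", "positions" and "quality" arrays (see Input Specifications)
result, state, meta = pipeline.apply(data, {}, None)
probs = result["probabilities"]  # (reference_length, num_classes)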

Parameters:

  • config (EnhancedVariantCallingPipelineConfig, required) - Pipeline configuration.
  • rngs (Rngs, required) - Random number generators for parameter initialization.
  • name (str | None, default None) - Optional name for the pipeline.

apply

apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]

Apply the enhanced variant calling pipeline.

Parameters:

  • data (dict[str, Array], required) - Input data containing:
      - reads: Float[Array, "num_reads read_length 4"]
      - positions: Int[Array, "num_reads"]
      - quality: Float[Array, "num_reads read_length"]
  • state (dict[str, Any], required) - Element state (passed through).
  • metadata (dict[str, Any] | None, required) - Element metadata (passed through).
  • random_params (Any, default None) - Random parameters for stochastic operations.
  • stats (dict[str, Any] | None, default None) - Optional statistics dict.

Returns:

  • tuple[dict[str, Array], dict[str, Any], dict[str, Any] | None] - Tuple of (output_data, state, metadata), where output_data contains all input keys plus the variant calling outputs.

EnhancedVariantCallingPipelineConfig

diffbio.pipelines.enhanced_variant_calling.EnhancedVariantCallingPipelineConfig (dataclass)

EnhancedVariantCallingPipelineConfig(
    reference_length: int = 1000,
    num_classes: int = 3,
    quality_threshold: float = 20.0,
    pileup_window_size: int = 11,
    cnn_input_height: int = 100,
    cnn_hidden_channels: tuple[int, ...] = (64, 128, 256),
    cnn_fc_dims: tuple[int, ...] = (256, 128),
    cnn_dropout_rate: float = 0.1,
    quality_recal_n_components: int = 3,
    quality_recal_n_features: int = 4,
    quality_recal_threshold: float = 0.5,
    enable_preprocessing: bool = True,
    enable_quality_recalibration: bool = True,
)

Bases: OperatorConfig

Configuration for the enhanced variant calling pipeline.

Attributes:

  • reference_length (int) - Length of the reference sequence.
  • num_classes (int) - Number of variant classes (default: 3 for ref/snp/indel).
  • quality_threshold (float) - Initial quality score threshold for filtering.
  • pileup_window_size (int) - Window size for pileup context.
  • cnn_input_height (int) - Height of the pileup image fed to the CNN (coverage depth).
  • cnn_hidden_channels (tuple[int, ...]) - Hidden channels for the CNN classifier.
  • cnn_fc_dims (tuple[int, ...]) - Fully connected layer dimensions for the CNN.
  • cnn_dropout_rate (float) - Dropout rate for the CNN classifier.
  • quality_recal_n_components (int) - Number of GMM components for quality recalibration.
  • quality_recal_n_features (int) - Number of features for quality recalibration.
  • quality_recal_threshold (float) - Filtering threshold used by quality recalibration.
  • enable_preprocessing (bool) - Whether to enable quality filtering preprocessing.
  • enable_quality_recalibration (bool) - Whether to enable quality recalibration.
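quality_threshold and enable_preprocessing describe a quality filter applied before the pileup. A hard cutoff (keep a base only if Q >= threshold) has zero gradient almost everywhere, so differentiable pipelines commonly replace it with a sigmoid. The sketch below assumes that form; `soft_quality_weights` and its `temperature` parameter are hypothetical, and diffbio's filter may be parameterised differently.

```python
import numpy as np

def soft_quality_weights(quality, threshold=20.0, temperature=1.0):
    """Hypothetical soft filter: ~0 well below threshold, ~1 well above, 0.5 exactly at it."""
    # Sigmoid of the scaled distance from the threshold; a smaller temperature gives a sharper cutoff
    return 1.0 / (1.0 + np.exp(-(quality - threshold) / temperature))

quality = np.array([[10.0, 20.0, 30.0]])
weights = soft_quality_weights(quality, threshold=20.0, temperature=2.0)
```

The derivative of these weights with respect to both the quality and the threshold is w(1 - w)/temperature, which is nonzero everywhere, so a threshold stored as a parameter can receive gradient updates.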

Factory Function

create_enhanced_variant_calling_pipeline

diffbio.pipelines.enhanced_variant_calling.create_enhanced_variant_calling_pipeline

create_enhanced_variant_calling_pipeline(
    reference_length: int = 1000,
    num_classes: int = 3,
    pileup_window_size: int = 11,
    cnn_hidden_channels: tuple[int, ...] | None = None,
    cnn_fc_dims: tuple[int, ...] | None = None,
    enable_preprocessing: bool = True,
    enable_quality_recalibration: bool = True,
    seed: int = 42,
) -> EnhancedVariantCallingPipeline

Factory function to create an enhanced variant calling pipeline.

Parameters:

  • reference_length (int, default 1000) - Length of the reference sequence.
  • num_classes (int, default 3) - Number of variant classes.
  • pileup_window_size (int, default 11) - Window size for pileup context.
  • cnn_hidden_channels (tuple[int, ...] | None, default None) - Hidden channels for the CNN classifier.
  • cnn_fc_dims (tuple[int, ...] | None, default None) - Fully connected layer dimensions for the CNN.
  • enable_preprocessing (bool, default True) - Whether to enable quality filtering.
  • enable_quality_recalibration (bool, default True) - Whether to enable quality recalibration.
  • seed (int, default 42) - Random seed.

Returns:

  • EnhancedVariantCallingPipeline - Configured EnhancedVariantCallingPipeline instance.
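pileup_window_size (default 11) controls how much neighbouring context accompanies each reference position. A minimal sketch of extracting one centered window per position from a (reference_length, 4) pileup is shown below; the zero padding at the edges, the odd-size requirement, and the `pileup_windows` name are assumptions for illustration.

```python
import numpy as np

def pileup_windows(pileup, window_size=11):
    """Hypothetical: (reference_length, 4) -> (reference_length, window_size, 4) centered windows."""
    # Assumed: an odd size, so every window has a single center position
    if window_size % 2 == 0:
        raise ValueError("window_size must be odd")
    half = window_size // 2
    # Zero-pad both ends so edge positions still get full-size windows
    padded = np.pad(pileup, ((half, half), (0, 0)))
    # Row i gathers padded[i : i + window_size], which is centered on original position i
    idx = np.arange(pileup.shape[0])[:, None] + np.arange(window_size)[None, :]
    return padded[idx]

pileup = np.random.default_rng(0).random((100, 4))
windows = pileup_windows(pileup, window_size=11)
```

The middle slice windows[:, window_size // 2] reproduces the original pileup, and the first and last positions see zeros where their context would fall off the reference.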

Usage Examples

Quick Start

from diffbio.pipelines import create_enhanced_variant_calling_pipeline
import jax
import jax.numpy as jnp

# Create pipeline
pipeline = create_enhanced_variant_calling_pipeline(
    reference_length=100,
    num_classes=3,
    pileup_window_size=11,
)

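# Prepare toy data: 20 reads of length 30 on a 100-bp reference
data = {
    # Soft one-hot reads (softmax over random values), for illustration only
    "reads": jax.nn.softmax(
        jax.random.uniform(jax.random.PRNGKey(0), (20, 30, 4)),
        axis=-1
    ),
    # Start positions in [0, 70), so every 30-bp read fits on the reference
    "positions": jax.random.randint(jax.random.PRNGKey(1), (20,), 0, 70),
    # Phred-scale base qualities between 10 and 40
    "quality": jax.random.uniform(jax.random.PRNGKey(2), (20, 30), minval=10, maxval=40),
}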

# Run pipeline
result, _, _ = pipeline.apply(data, {}, None)
predictions = jnp.argmax(result["probabilities"], axis=-1)
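predictions holds one class index per reference position. To turn the probabilities into candidate calls, one option is to keep positions whose non-reference probability clears a cutoff and report the most likely non-reference class. In the sketch below, the label order ("ref", "snp", "indel", taken from the num_classes description), the 0.5 cutoff, and the `call_variants` helper are all assumptions.

```python
import numpy as np

CLASS_NAMES = ("ref", "snp", "indel")  # assumed index order

def call_variants(probabilities, min_alt_prob=0.5):
    """Hypothetical: return (position, label, alt_prob) for likely variant positions."""
    probabilities = np.asarray(probabilities)
    # Probability mass on any non-reference class
    alt_prob = 1.0 - probabilities[:, 0]
    calls = []
    for pos in np.flatnonzero(alt_prob >= min_alt_prob):
        # Best non-reference class (offset by 1 to skip "ref")
        label = CLASS_NAMES[1 + int(np.argmax(probabilities[pos, 1:]))]
        calls.append((int(pos), label, float(alt_prob[pos])))
    return calls

probs = np.array([[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.3, 0.1, 0.6]])
calls = call_variants(probs)
```

Because np.asarray accepts JAX arrays, the same helper can be applied to result["probabilities"] from the Quick Start.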

Full Configuration

from diffbio.pipelines import EnhancedVariantCallingPipeline, EnhancedVariantCallingPipelineConfig
from flax import nnx

config = EnhancedVariantCallingPipelineConfig(
    reference_length=10000,
    num_classes=3,
    quality_threshold=20.0,
    pileup_window_size=21,
    cnn_hidden_channels=(64, 128, 256),
    cnn_fc_dims=(256, 128),
    cnn_dropout_rate=0.2,
    enable_preprocessing=True,
    enable_quality_recalibration=True,
)

pipeline = EnhancedVariantCallingPipeline(config, rngs=nnx.Rngs(42))
# Note: this pipeline has no training-mode toggle; dropout state is managed
# by submodules directly when applicable.

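Training Mode

# EnhancedVariantCallingPipeline does not expose train_mode/eval_mode toggles.
# Submodules that use dropout manage their own state during apply().
# `dataloader` and `train_step` are user-supplied: train_step should compute a
# loss from the pipeline outputs and update the parameters (e.g. with optax).
for batch in dataloader:
    loss = train_step(pipeline, batch)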
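The source does not specify a training target, but because the pipeline outputs per-position logits, per-position cross-entropy against integer class labels is a natural candidate loss for train_step. The NumPy sketch below computes that loss and its gradient with respect to the logits (softmax minus one-hot, averaged), which is the gradient jax.grad would produce at the logits before propagating it back through the CNN and pileup. The `labels` array and the `cross_entropy` name are assumptions.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean per-position cross-entropy and its gradient w.r.t. logits (n, num_classes)."""
    # Numerically stable log-softmax
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    n = logits.shape[0]
    loss = -log_probs[np.arange(n), labels].mean()
    # d(loss)/d(logits) = (softmax - one_hot) / n
    grad = np.exp(log_probs)
    grad[np.arange(n), labels] -= 1.0
    return loss, grad / n

logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
labels = np.array([0, 2])  # hypothetical per-position labels (0 = ref in the assumed order)
loss, grad = cross_entropy(logits, labels)
```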

Access Components

# Quality filter (if enabled)
if pipeline.quality_filter is not None:
    pipeline.quality_filter.threshold[...]

# Pileup generator
pipeline.pileup.config.temperature

# CNN classifier
pipeline.cnn_classifier

# Quality recalibration (if enabled)
if pipeline.quality_recalibration is not None:
    pipeline.quality_recalibration

Input Specifications

  • reads: (num_reads, read_length, 4) - One-hot encoded reads
  • positions: (num_reads,) - Read start positions on the reference
  • quality: (num_reads, read_length) - Phred quality scores
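Before calling apply(), it can help to confirm that the three input arrays agree with the specification above. The `validate_inputs` helper below is hypothetical, and its requirement that every read lie fully inside the reference is an assumption (the pipeline may instead clip or ignore out-of-range bases).

```python
import numpy as np

def validate_inputs(data, reference_length):
    """Hypothetical shape and range checks mirroring the input specification."""
    reads, positions, quality = (np.asarray(data[k]) for k in ("reads", "positions", "quality"))
    num_reads, read_length, alphabet = reads.shape
    if alphabet != 4:
        raise ValueError(f"reads last axis must be 4 (one-hot bases), got {alphabet}")
    if positions.shape != (num_reads,):
        raise ValueError(f"positions shape {positions.shape} != ({num_reads},)")
    if quality.shape != (num_reads, read_length):
        raise ValueError(f"quality shape {quality.shape} != {(num_reads, read_length)}")
    # Assumed: each read must lie fully inside [0, reference_length)
    if positions.min() < 0 or positions.max() + read_length > reference_length:
        raise ValueError("some reads extend beyond the reference")

rng = np.random.default_rng(0)
data = {
    "reads": np.eye(4)[rng.integers(0, 4, size=(20, 30))],   # (20, 30, 4)
    "positions": rng.integers(0, 70, size=20),
    "quality": rng.uniform(10, 40, size=(20, 30)),
}
validate_inputs(data, reference_length=100)  # passes: every read ends before position 100
```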

Output Specifications

  • reads: (num_reads, read_length, 4) - Original reads
  • positions: (num_reads,) - Original positions
  • quality: (num_reads, read_length) - Original quality
  • pileup: (reference_length, 4) - Aggregated pileup
  • logits: (reference_length, num_classes) - Raw predictions
  • probabilities: (reference_length, num_classes) - Class probabilities
  • quality_scores: (reference_length,) - Recalibrated quality*
  • filter_weights: (reference_length,) - Soft filter weights*

*Only present when enable_quality_recalibration=True.
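quality_scores and filter_weights come from the VQSR-style recalibration stage, configured by quality_recal_n_components, quality_recal_n_features and quality_recal_threshold. The sketch below is one plausible reading of those parameters, not diffbio's code: score each position's feature vector under a "good" and a "bad" diagonal Gaussian mixture, convert the log-likelihood ratio to P(good) assuming equal priors, Phred-scale the error probability, and apply a sigmoid around the threshold to get soft filter weights. The two-model setup, the fixed mixture parameters, the temperature, and all function names are assumptions.

```python
import numpy as np

def gmm_log_likelihood(x, weights, means, variances):
    """Log-density of x, shape (n, d), under a k-component diagonal Gaussian mixture."""
    diff = x[:, None, :] - means[None, :, :]                             # (n, k, d)
    log_comp = (np.log(weights)
                - 0.5 * np.sum(np.log(2 * np.pi * variances), axis=-1)
                - 0.5 * np.sum(diff**2 / variances, axis=-1))            # (n, k)
    m = log_comp.max(axis=-1, keepdims=True)                             # log-sum-exp over k
    return m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=-1))

def recalibrate(features, good_gmm, bad_gmm, threshold=0.5, temperature=0.05):
    """Hypothetical VQSR-style step: returns (phred_quality, filter_weights)."""
    # With equal priors, P(good | x) = sigmoid(log p_good(x) - log p_bad(x))
    lod = gmm_log_likelihood(features, *good_gmm) - gmm_log_likelihood(features, *bad_gmm)
    p_good = 1.0 / (1.0 + np.exp(-lod))
    # Phred-scale the error probability, clipped to avoid log(0)
    quality = -10.0 * np.log10(np.clip(1.0 - p_good, 1e-10, 1.0))
    # Soft filter: ~1 when p_good is well above the threshold, ~0 well below it
    filter_weights = 1.0 / (1.0 + np.exp(-(p_good - threshold) / temperature))
    return quality, filter_weights

k, d = 3, 4  # defaults of quality_recal_n_components and quality_recal_n_features
good_gmm = (np.full(k, 1.0 / k), np.zeros((k, d)), np.ones((k, d)))     # hypothetical "true variant" model
bad_gmm = (np.full(k, 1.0 / k), np.full((k, d), 3.0), np.ones((k, d)))  # hypothetical "artifact" model
features = np.array([[0.0] * d, [3.0] * d])  # first row matches the good model, second the bad one
quality, filter_weights = recalibrate(features, good_gmm, bad_gmm, threshold=0.5)
```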