# Enhanced Variant Calling Pipeline API
DeepVariant-style end-to-end differentiable variant calling pipeline with CNN classifier and quality recalibration.
## EnhancedVariantCallingPipeline

`diffbio.pipelines.enhanced_variant_calling.EnhancedVariantCallingPipeline`
```python
EnhancedVariantCallingPipeline(
    config: EnhancedVariantCallingPipelineConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)
```
Bases: OperatorModule
Enhanced end-to-end differentiable variant calling pipeline.
This pipeline calls variants from sequencing reads in four stages: optional quality filtering (preprocessing), pileup generation, classification by a DeepVariant-style CNN, and optional VQSR-style quality recalibration.
Input data structure:

- `reads`: `Float[Array, "num_reads read_length 4"]`, one-hot encoded reads
- `positions`: `Int[Array, "num_reads"]`, read start positions on the reference
- `quality`: `Float[Array, "num_reads read_length"]`, base quality scores

Output data structure (keys added to the input):

- `pileup`: `Float[Array, "reference_length 4"]`, aggregated base frequencies
- `logits`: `Float[Array, "reference_length num_classes"]`, raw class predictions
- `probabilities`: `Float[Array, "reference_length num_classes"]`, class probabilities
- `quality_scores`: `Float[Array, "reference_length"]`, recalibrated quality scores (only when quality recalibration is enabled)
- `filter_weights`: `Float[Array, "reference_length"]`, soft filter weights (only when quality recalibration is enabled)
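The `pileup` output aggregates per-base evidence from all reads into reference coordinates. One way such an aggregation can be written in JAX is sketched below. It is an illustrative assumption, not diffbio's implementation (the real pileup module also has a `temperature` setting that this sketch ignores); `soft_pileup` is a hypothetical name, and weighting each base by its Phred-derived probability of being correct is one plausible choice.

```python
import jax
import jax.numpy as jnp


def soft_pileup(reads, positions, quality, reference_length):
    """Hypothetical sketch: aggregate reads into a (reference_length, 4) pileup.

    reads: (num_reads, read_length, 4) one-hot or soft base probabilities
    positions: (num_reads,) non-negative start positions
    quality: (num_reads, read_length) Phred base qualities
    """
    _, read_length, _ = reads.shape
    # Reference coordinate covered by each base: (num_reads, read_length)
    coords = positions[:, None] + jnp.arange(read_length)[None, :]
    # Phred Q -> probability that the base call is correct: 1 - 10^(-Q/10)
    weights = 1.0 - 10.0 ** (-quality / 10.0)
    weighted = reads * weights[..., None]
    # Scatter-add every weighted base into its reference column; bases that
    # fall past the reference end are dropped
    counts = jnp.zeros((reference_length, 4)).at[coords.reshape(-1)].add(
        weighted.reshape(-1, 4), mode="drop"
    )
    # Normalize to base frequencies; uncovered columns stay all-zero
    totals = counts.sum(axis=-1, keepdims=True)
    return counts / jnp.maximum(totals, 1e-8)
```

Because every step is a sum or a smooth function, gradients flow from the pileup back to the read encodings and qualities.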
The pipeline is fully differentiable, supporting gradient-based training to optimize all components jointly.
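Joint training only works if every stage passes gradients. A hard rule such as `quality >= threshold` has zero gradient almost everywhere, so differentiable pipelines typically replace it with a smooth relaxation. The sketch below uses a sigmoid to show the idea; `soft_quality_mask`, `kept_fraction`, and `temperature` are hypothetical names, not diffbio's quality filter API.

```python
import jax
import jax.numpy as jnp


def soft_quality_mask(quality, threshold, temperature=1.0):
    """Sigmoid relaxation of the hard rule `quality >= threshold`.

    Smaller `temperature` values make the step sharper.
    """
    return jax.nn.sigmoid((quality - threshold) / temperature)


def kept_fraction(threshold, quality):
    # Expected fraction of bases that pass the soft filter
    return soft_quality_mask(quality, threshold).mean()


quality = jnp.array([[10.0, 20.0, 30.0, 40.0]])
# The sigmoid has a non-zero gradient, so the threshold itself can be
# updated by gradient descent together with downstream parameters.
grad_threshold = jax.grad(kept_fraction)(20.0, quality)
```

The gradient is negative: raising the threshold keeps fewer bases.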
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `EnhancedVariantCallingPipelineConfig` | Pipeline configuration. | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization. | *required* |
| `name` | `str \| None` | Optional name for the pipeline. | `None` |
### apply
```python
apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]
```
Apply the enhanced variant calling pipeline.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Array]` | Input data containing `reads` (`Float[Array, "num_reads read_length 4"]`), `positions` (`Int[Array, "num_reads"]`), and `quality` (`Float[Array, "num_reads read_length"]`). | *required* |
| `state` | `dict[str, Any]` | Element state (passed through). | *required* |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through). | *required* |
| `random_params` | `Any` | Random parameters for stochastic operations. | `None` |
| `stats` | `dict[str, Any] \| None` | Optional statistics dict. | `None` |
Returns:

| Type | Description |
|---|---|
| `tuple[dict[str, Array], dict[str, Any], dict[str, Any] \| None]` | Tuple of `(output_data, state, metadata)`, where `output_data` contains all input keys plus the variant calling outputs. |
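To make the return contract concrete, here is a hypothetical test helper (`check_variant_outputs` is not part of diffbio) that asserts the documented keys and shapes of `output_data`. It assumes `probabilities` is normalized over the class axis.

```python
import jax.numpy as jnp


def check_variant_outputs(output_data, input_data, reference_length, num_classes,
                          recalibration=True):
    """Hypothetical helper: assert an apply() result matches the documented layout."""
    # Every input key must still be present with its original shape
    for key, value in input_data.items():
        assert key in output_data, f"missing input key {key!r}"
        assert output_data[key].shape == value.shape, key
    expected = {
        "pileup": (reference_length, 4),
        "logits": (reference_length, num_classes),
        "probabilities": (reference_length, num_classes),
    }
    if recalibration:
        # Only produced when enable_quality_recalibration=True
        expected["quality_scores"] = (reference_length,)
        expected["filter_weights"] = (reference_length,)
    for key, shape in expected.items():
        assert key in output_data, f"missing output key {key!r}"
        assert output_data[key].shape == shape, (key, output_data[key].shape)
    # Assumes probabilities are normalized over classes at every position
    assert jnp.allclose(output_data["probabilities"].sum(axis=-1), 1.0, atol=1e-5)
    return True
```

For example, after the Quick Start call you could run `check_variant_outputs(result, data, 100, 3)`.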
## EnhancedVariantCallingPipelineConfig

`diffbio.pipelines.enhanced_variant_calling.EnhancedVariantCallingPipelineConfig` (dataclass)
```python
EnhancedVariantCallingPipelineConfig(
    reference_length: int = 1000,
    num_classes: int = 3,
    quality_threshold: float = 20.0,
    pileup_window_size: int = 11,
    cnn_input_height: int = 100,
    cnn_hidden_channels: tuple[int, ...] = (64, 128, 256),
    cnn_fc_dims: tuple[int, ...] = (256, 128),
    cnn_dropout_rate: float = 0.1,
    quality_recal_n_components: int = 3,
    quality_recal_n_features: int = 4,
    quality_recal_threshold: float = 0.5,
    enable_preprocessing: bool = True,
    enable_quality_recalibration: bool = True,
)
```
Bases: OperatorConfig
Configuration for the enhanced variant calling pipeline.
Attributes:

| Name | Type | Description |
|---|---|---|
| `reference_length` | `int` | Length of the reference sequence. |
| `num_classes` | `int` | Number of variant classes (default 3: reference, SNP, indel). |
| `quality_threshold` | `float` | Initial quality score threshold for filtering. |
| `pileup_window_size` | `int` | Window size for pileup context. |
| `cnn_input_height` | `int` | Height of the pileup image fed to the CNN (corresponds to coverage depth). |
| `cnn_hidden_channels` | `tuple[int, ...]` | Hidden channels for the CNN classifier. |
| `cnn_fc_dims` | `tuple[int, ...]` | Fully connected layer dimensions for the CNN. |
| `cnn_dropout_rate` | `float` | Dropout rate for the CNN classifier. |
| `quality_recal_n_components` | `int` | Number of GMM components for quality recalibration. |
| `quality_recal_n_features` | `int` | Number of features for quality recalibration. |
| `quality_recal_threshold` | `float` | Filtering threshold used by the quality recalibration stage. |
| `enable_preprocessing` | `bool` | Whether to apply quality filtering as a preprocessing step. |
| `enable_quality_recalibration` | `bool` | Whether to enable quality recalibration. |
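`pileup_window_size` controls how much neighboring context the classifier sees around each position. A minimal sketch of extracting centered windows from a `(reference_length, 4)` pileup follows. It assumes zero padding at the reference ends (diffbio's own boundary handling is not documented here), and `pileup_windows` is a hypothetical helper.

```python
import jax.numpy as jnp


def pileup_windows(pileup, window_size):
    """Hypothetical sketch: gather a context window around every position.

    pileup: (reference_length, 4) -> (reference_length, window_size, 4).
    Odd sizes give centered windows; positions near the ends are zero-padded.
    """
    left = window_size // 2
    right = window_size - 1 - left
    padded = jnp.pad(pileup, ((left, right), (0, 0)))
    # Row i of `idx` selects padded rows i .. i + window_size - 1, which are
    # original positions i - left .. i + right
    idx = jnp.arange(pileup.shape[0])[:, None] + jnp.arange(window_size)[None, :]
    return padded[idx]
```

With the default `pileup_window_size=11`, a 1000-position pileup becomes a `(1000, 11, 4)` array, one small image per candidate site.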
## Factory Function

### create_enhanced_variant_calling_pipeline

`diffbio.pipelines.enhanced_variant_calling.create_enhanced_variant_calling_pipeline`
```python
create_enhanced_variant_calling_pipeline(
    reference_length: int = 1000,
    num_classes: int = 3,
    pileup_window_size: int = 11,
    cnn_hidden_channels: tuple[int, ...] | None = None,
    cnn_fc_dims: tuple[int, ...] | None = None,
    enable_preprocessing: bool = True,
    enable_quality_recalibration: bool = True,
    seed: int = 42,
) -> EnhancedVariantCallingPipeline
```
Factory function to create an enhanced variant calling pipeline.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `reference_length` | `int` | Length of the reference sequence. | `1000` |
| `num_classes` | `int` | Number of variant classes. | `3` |
| `pileup_window_size` | `int` | Window size for pileup context. | `11` |
| `cnn_hidden_channels` | `tuple[int, ...] \| None` | Hidden channels for the CNN classifier. | `None` |
| `cnn_fc_dims` | `tuple[int, ...] \| None` | Fully connected dimensions for the CNN. | `None` |
| `enable_preprocessing` | `bool` | Whether to enable quality filtering. | `True` |
| `enable_quality_recalibration` | `bool` | Whether to enable quality recalibration. | `True` |
| `seed` | `int` | Random seed. | `42` |
Returns:

| Type | Description |
|---|---|
| `EnhancedVariantCallingPipeline` | Configured `EnhancedVariantCallingPipeline` instance. |
## Usage Examples

### Quick Start
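```python
import jax
import jax.numpy as jnp

from diffbio.pipelines import create_enhanced_variant_calling_pipeline

# Create the pipeline
pipeline = create_enhanced_variant_calling_pipeline(
    reference_length=100,
    num_classes=3,
    pileup_window_size=11,
)

# Synthetic data: 20 reads of length 30. A softmax over random values gives
# a soft (relaxed) one-hot encoding for each base.
data = {
    "reads": jax.nn.softmax(
        jax.random.uniform(jax.random.PRNGKey(0), (20, 30, 4)),
        axis=-1,
    ),
    # Starts in [0, 70) keep every 30-base read inside the 100-base reference
    "positions": jax.random.randint(jax.random.PRNGKey(1), (20,), 0, 70),
    # Phred-scale base qualities between 10 and 40
    "quality": jax.random.uniform(
        jax.random.PRNGKey(2), (20, 30), minval=10, maxval=40
    ),
}

# Run the pipeline; state and metadata are passed through
result, _, _ = pipeline.apply(data, {}, None)
predictions = jnp.argmax(result["probabilities"], axis=-1)  # (100,) class indices
```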
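After `apply`, the per-position probabilities usually need to be turned into a list of calls. The sketch below does this without diffbio. The class order (0 = reference, 1 = SNP, 2 = indel) is an assumption based on the `num_classes` description, and `summarize_calls`, `CLASS_NAMES`, and `min_qual` are hypothetical names. Confidence is reported Phred-style as `-10 * log10(1 - p)`.

```python
import jax.numpy as jnp

# Assumed class order (index -> label); check it against your training labels
CLASS_NAMES = ("ref", "snp", "indel")


def summarize_calls(probabilities, min_qual=20.0):
    """Hypothetical post-processing of (reference_length, num_classes) probabilities.

    Returns (position, label, phred_quality) for non-reference calls at or
    above `min_qual`.
    """
    predictions = jnp.argmax(probabilities, axis=-1)
    p_call = jnp.max(probabilities, axis=-1)
    # Phred-scaled probability that the call is wrong; floor avoids log10(0)
    qual = -10.0 * jnp.log10(jnp.maximum(1.0 - p_call, 1e-10))
    variant_sites = jnp.nonzero((predictions != 0) & (qual >= min_qual))[0]
    return [
        (int(pos), CLASS_NAMES[int(predictions[pos])], float(qual[pos]))
        for pos in variant_sites
    ]
```

Applied to the Quick Start output, `summarize_calls(result["probabilities"])` lists candidate variant sites.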
### Full Configuration
```python
from flax import nnx

from diffbio.pipelines import (
    EnhancedVariantCallingPipeline,
    EnhancedVariantCallingPipelineConfig,
)

config = EnhancedVariantCallingPipelineConfig(
    reference_length=10000,
    num_classes=3,
    quality_threshold=20.0,
    pileup_window_size=21,
    cnn_hidden_channels=(64, 128, 256),
    cnn_fc_dims=(256, 128),
    cnn_dropout_rate=0.2,
    enable_preprocessing=True,
    enable_quality_recalibration=True,
)

pipeline = EnhancedVariantCallingPipeline(config, rngs=nnx.Rngs(42))
# Note: this pipeline has no training-mode toggle; submodules that use
# dropout manage that state themselves where applicable.
```
### Training Mode

```python
# EnhancedVariantCallingPipeline does not expose train_mode/eval_mode toggles.
# Submodules that use dropout manage their own state during apply().
# `dataloader` and `train_step` stand for your own data iterator and
# optimization step.
for batch in dataloader:
    loss = train_step(pipeline, batch)
```
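The loop leaves `dataloader` and `train_step` to the user. As a self-contained illustration of the optimization pattern only, the sketch below trains a toy linear head on per-position pileup features with per-position cross-entropy and plain gradient descent. The linear head is a hypothetical stand-in for the CNN classifier, `toy_train_step` is not diffbio code, and training the real pipeline would additionally need Flax NNX parameter handling, which is not shown.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, labels):
    """Mean per-position softmax cross-entropy; labels are class indices."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
    return -picked.mean()


def loss_fn(params, pileup, labels):
    # Toy linear head: (reference_length, 4) -> (reference_length, num_classes)
    logits = pileup @ params["w"] + params["b"]
    return cross_entropy(logits, labels)


@jax.jit
def toy_train_step(params, pileup, labels, lr=0.5):
    loss, grads = jax.value_and_grad(loss_fn)(params, pileup, labels)
    # Plain gradient descent on every leaf of the parameter pytree
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


# Synthetic data: per-position base frequencies and integer class labels
pileup = jax.nn.softmax(jax.random.normal(jax.random.PRNGKey(0), (50, 4)), axis=-1)
labels = jnp.argmax(pileup, axis=-1) % 3
params = {"w": jnp.zeros((4, 3)), "b": jnp.zeros(3)}
losses = []
for _ in range(100):
    params, loss = toy_train_step(params, pileup, labels)
    losses.append(float(loss))
```

With zero initialization the first loss is `log(3)`, the loss of a uniform prediction over three classes, and it decreases from there.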
### Access Components

```python
# Quality filter (None when enable_preprocessing=False)
if pipeline.quality_filter is not None:
    threshold = pipeline.quality_filter.threshold[...]  # current threshold value

# Pileup generator
temperature = pipeline.pileup.config.temperature

# CNN classifier
cnn = pipeline.cnn_classifier

# Quality recalibration (None when enable_quality_recalibration=False)
if pipeline.quality_recalibration is not None:
    recal = pipeline.quality_recalibration
```
## Input Specifications

| Key | Shape | Description |
|---|---|---|
| `reads` | `(num_reads, read_length, 4)` | One-hot encoded reads |
| `positions` | `(num_reads,)` | Read start positions on the reference |
| `quality` | `(num_reads, read_length)` | Phred quality scores |
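Real reads usually arrive as base strings with FASTQ quality strings. The helper below (hypothetical, not part of diffbio) converts reads into the `reads` and `quality` layouts from the table. It assumes the one-hot channel order A, C, G, T and Phred+33 quality encoding (Sanger / Illumina 1.8+); check both against how your model was trained.

```python
import jax.numpy as jnp
import numpy as np

BASES = "ACGT"  # assumed order of the 4 one-hot channels


def encode_read(seq, qual_string, phred_offset=33):
    """Hypothetical helper: encode one read and its FASTQ quality string.

    Returns (one_hot, quality) with shapes (len(seq), 4) and (len(seq),).
    Ambiguous bases such as 'N' become all-zero rows.
    """
    if len(seq) != len(qual_string):
        raise ValueError("sequence and quality string lengths differ")
    one_hot = np.zeros((len(seq), 4), dtype=np.float32)
    for i, base in enumerate(seq.upper()):
        channel = BASES.find(base)
        if channel >= 0:
            one_hot[i, channel] = 1.0
    # Phred+33: quality = ASCII code - 33 (e.g. 'I' -> 40)
    quality = np.array([ord(c) - phred_offset for c in qual_string], dtype=np.float32)
    return jnp.asarray(one_hot), jnp.asarray(quality)


# Equal-length reads can be stacked into the batched layout
pairs = [encode_read("ACGT", "IIII"), encode_read("TTGN", "5#II")]
reads = jnp.stack([one_hot for one_hot, _ in pairs])  # (2, 4, 4)
quality = jnp.stack([qual for _, qual in pairs])      # (2, 4)
```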
## Output Specifications

| Key | Shape | Description |
|---|---|---|
| `reads` | `(num_reads, read_length, 4)` | Original reads |
| `positions` | `(num_reads,)` | Original positions |
| `quality` | `(num_reads, read_length)` | Original quality scores |
| `pileup` | `(reference_length, 4)` | Aggregated pileup |
| `logits` | `(reference_length, num_classes)` | Raw predictions |
| `probabilities` | `(reference_length, num_classes)` | Class probabilities |
| `quality_scores` | `(reference_length,)` | Recalibrated quality\* |
| `filter_weights` | `(reference_length,)` | Soft filter weights\* |

\*Only present when `enable_quality_recalibration=True`.
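The recalibration outputs are described as VQSR-style. In GATK's VQSR, each site is scored by the log-odds between a Gaussian mixture model fitted to trusted variants and one fitted to likely artifacts. The sketch below reproduces that idea with fixed diagonal-covariance GMMs and a sigmoid soft filter. It is a simplified illustration, not diffbio's `quality_recalibration` module, and `diag_gmm_logpdf`, `recalibrate`, `good_gmm`, `bad_gmm`, and `temperature` are hypothetical names. With the default config, `features` would have `quality_recal_n_features = 4` columns and each GMM `quality_recal_n_components = 3` components.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def diag_gmm_logpdf(x, weights, means, log_stds):
    """Log-density of x (N, F) under a diagonal-covariance GMM with K components.

    weights: (K,) mixture weights summing to 1; means, log_stds: (K, F).
    """
    diff = (x[:, None, :] - means[None]) / jnp.exp(log_stds)[None]  # (N, K, F)
    comp = -0.5 * jnp.sum(diff**2 + 2 * log_stds[None] + jnp.log(2 * jnp.pi), axis=-1)
    return logsumexp(comp + jnp.log(weights)[None], axis=-1)  # (N,)


def recalibrate(features, good_gmm, bad_gmm, threshold=0.5, temperature=0.1):
    """Hypothetical VQSR-style step: log-odds of 'good' vs 'bad' GMMs per site."""
    lod = diag_gmm_logpdf(features, *good_gmm) - diag_gmm_logpdf(features, *bad_gmm)
    p_good = jax.nn.sigmoid(lod)  # equal priors on the two models
    # Phred-scaled probability that the site is an artifact
    quality_scores = -10.0 * jnp.log10(jnp.maximum(1.0 - p_good, 1e-10))
    # Soft pass/fail instead of a hard cut at `threshold`
    filter_weights = jax.nn.sigmoid((p_good - threshold) / temperature)
    return quality_scores, filter_weights
```

Both outputs are smooth in the GMM parameters, so they can be trained jointly with the classifier.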