Pileup Operator¤
The DifferentiablePileup operator generates soft read pileups for variant calling, aggregating aligned reads into position-wise nucleotide distributions.
Overview¤
Pileup generation aggregates aligned sequencing reads at each reference position. Unlike traditional pileup tools that produce integer counts, DiffBio's implementation produces continuous distributions that enable gradient flow for end-to-end training.
Quick Start¤
import jax
import jax.numpy as jnp
from diffbio.operators.variant import DifferentiablePileup, PileupConfig
# Configure pileup
config = PileupConfig(
    reference_length=100,
    use_quality_weights=True
)
# Create operator
pileup_op = DifferentiablePileup(config)
# Prepare data
num_reads = 20
read_length = 30
reads = jax.random.uniform(jax.random.PRNGKey(0), (num_reads, read_length, 4))
reads = jax.nn.softmax(reads, axis=-1) # Soft one-hot
positions = jax.random.randint(jax.random.PRNGKey(1), (num_reads,), 0, 70)
quality = jax.random.uniform(jax.random.PRNGKey(2), (num_reads, read_length), 10, 40)
# Generate pileup
data = {"reads": reads, "positions": positions, "quality": quality}
result, _, _ = pileup_op.apply(data, {}, None)
print(f"Pileup shape: {result['pileup'].shape}") # (100, 4)
Configuration¤
PileupConfig¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| reference_length | int | 100 | Length of reference sequence |
| use_quality_weights | bool | True | Weight bases by quality scores |
| return_coverage | bool | False | Include soft coverage as an extra output |
| return_quality | bool | False | Include mean-quality channel as an extra output |
| apply_softmax | bool | True | Normalize base channels into distributions |
| stochastic | bool | False | Whether operator uses randomness |
from diffbio.operators.variant import PileupConfig
config = PileupConfig(
    reference_length=10000, # Must match your reference
    use_quality_weights=True, # Recommended for quality-aware pileup
    return_coverage=True, # Emit soft coverage channel
    return_quality=True, # Emit mean-quality channel
)
API Reference¤
DifferentiablePileup¤
class DifferentiablePileup(OperatorModule):
    def __init__(
        self,
        config: PileupConfig,
        *,
        rngs: nnx.Rngs | None = None,
        name: str | None = None,
    ):
        """Initialize differentiable pileup generator.

        Args:
            config: Pileup configuration
            rngs: Random number generators (optional)
            name: Optional operator name
        """
Methods¤
compute_pileup()¤
def compute_pileup(
    self,
    reads: Float[Array, "num_reads read_length 4"],
    positions: Int[Array, "num_reads"],
    quality: Float[Array, "num_reads read_length"],
    reference_length: int,
) -> dict[str, Float[Array, "..."]]:
    """Generate pileup from aligned reads.

    Args:
        reads: One-hot encoded reads
        positions: Starting position of each read
        quality: Quality scores for each base
        reference_length: Length of reference sequence

    Returns:
        Dictionary containing `"pileup"` and optional `"coverage"` /
        `"mean_quality"` channels
    """
apply()¤
def apply(
    self,
    data: PyTree,
    state: PyTree,
    metadata: dict | None,
    random_params: Any = None,
    stats: dict | None = None,
) -> tuple[PyTree, PyTree, dict | None]:
    """Apply pileup generation (Datarax interface).

    Expected data keys:
        - "reads": One-hot encoded reads (num_reads, read_length, 4)
        - "positions": Starting position of each read (num_reads,)
        - "quality": Quality scores (num_reads, read_length)

    Output data keys:
        - "reads", "positions", "quality": Original inputs
        - "pileup": Generated pileup (reference_length, 4)
    """
Input Format¤
Reads¤
One-hot encoded reads with shape (num_reads, read_length, 4):
# Hard one-hot (typical)
read_indices = jnp.array([0, 1, 2, 3, 0, 1]) # ACGTAC
read = jnp.eye(4)[read_indices] # (6, 4)
# Soft one-hot (for training)
read_soft = jax.nn.softmax(logits, axis=-1) # (length, 4)
Positions¤
Integer starting positions for each read:
positions = jnp.array([10, 25, 42, ...]) # (num_reads,)
# Read i covers positions[i] to positions[i] + read_length - 1
Quality Scores¤
Phred quality scores for each base:
quality = jnp.array([
    [30, 35, 28, 40, ...], # Read 1 quality scores
    [25, 30, 32, 35, ...], # Read 2 quality scores
    ...
]) # (num_reads, read_length)
Output Format¤
The output pileup has shape (reference_length, 4):
pileup = result['pileup']
# At each position, pileup[i] is a probability distribution
# pileup[i, 0] = P(A at position i)
# pileup[i, 1] = P(C at position i)
# pileup[i, 2] = P(G at position i)
# pileup[i, 3] = P(T at position i)
# Sum to 1 at each position
assert jnp.allclose(pileup.sum(axis=-1), 1.0)
Quality Weighting¤
When use_quality_weights=True, bases are weighted by quality:
| Phred Score | Error Rate | Weight |
|---|---|---|
| 10 | 10% | 0.90 |
| 20 | 1% | 0.99 |
| 30 | 0.1% | 0.999 |
| 40 | 0.01% | 0.9999 |
# Quality weighting in action
def quality_to_weight(phred_scores):
    p_error = jnp.power(10.0, -phred_scores / 10.0)
    return 1.0 - p_error
# High quality base contributes more
weight_q30 = quality_to_weight(30) # 0.999
weight_q10 = quality_to_weight(10) # 0.9
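To make the effect of this weighting concrete, the following sketch combines two disagreeing reads at a single position. It is a minimal illustration in plain JAX, assuming the pileup column is a weighted average of the base vectors; the variable names are illustrative and not part of DiffBio.

```python
import jax.numpy as jnp

def quality_to_weight(phred_scores):
    """Convert Phred scores to weights equal to 1 - P(error)."""
    scores = jnp.asarray(phred_scores, dtype=jnp.float32)
    return 1.0 - jnp.power(10.0, -scores / 10.0)

# Two reads cover the same position: one calls A at Q40, the other C at Q10.
bases = jnp.array([
    [1.0, 0.0, 0.0, 0.0],  # A
    [0.0, 1.0, 0.0, 0.0],  # C
])
weights = quality_to_weight(jnp.array([40.0, 10.0]))

# Weighted average of the base vectors at this position.
column = (bases * weights[:, None]).sum(axis=0) / weights.sum()
print(column)
```

Because the weights are 0.9999 and 0.9, the Q40 call only slightly outweighs the Q10 call (roughly 0.53 versus 0.47 in this column). With this weighting scheme, low-quality bases are down-weighted gently rather than suppressed.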
Advanced Usage¤
Variant Detection from Pileup¤
def detect_variants(pileup, reference):
    """Identify positions where pileup differs from reference.

    Args:
        pileup: (reference_length, 4) nucleotide distributions
        reference: (reference_length, 4) one-hot reference

    Returns:
        Variant scores at each position
    """
    # Probability of reference base
    ref_prob = (pileup * reference).sum(axis=-1)
    # Variant probability = 1 - reference probability
    variant_prob = 1.0 - ref_prob
    return variant_prob
variant_scores = detect_variants(result['pileup'], reference_onehot)
high_confidence_variants = variant_scores > 0.3
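As a worked example of this scoring, the sketch below applies the same computation to a hand-written four-position pileup. The pileup values are made up for illustration; in practice they would come from the operator output.

```python
import jax.numpy as jnp

# Reference "ACGT" as one-hot, shape (4, 4).
reference = jnp.eye(4)[jnp.array([0, 1, 2, 3])]

# Toy pileup: positions 0, 1 and 3 agree with the reference,
# position 2 is split between G and T.
pileup = jnp.array([
    [0.97, 0.01, 0.01, 0.01],
    [0.01, 0.97, 0.01, 0.01],
    [0.01, 0.01, 0.50, 0.48],
    [0.01, 0.01, 0.01, 0.97],
])

# Same computation as detect_variants: 1 - P(reference base).
variant_scores = 1.0 - (pileup * reference).sum(axis=-1)

# Indices of positions whose score exceeds the 0.3 threshold.
candidate_sites = jnp.where(variant_scores > 0.3)[0]
print(variant_scores, candidate_sites)
```

Only position 2 crosses the threshold here, since half of its probability mass sits on a non-reference base.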
Coverage Analysis¤
config = PileupConfig(
    reference_length=1000,
    return_coverage=True,
    use_quality_weights=False,
)
pileup_op = DifferentiablePileup(config)
result, _, _ = pileup_op.apply(
    {"reads": reads, "positions": positions, "quality": quality},
    {},
    None,
)
coverage = result["coverage"].squeeze(-1)
low_coverage_mask = coverage < 5.0
Gradient-Based Analysis¤
import jax
def pileup_entropy(reads, positions, quality, config):
    """Compute entropy of pileup (measure of uncertainty)."""
    pileup_op = DifferentiablePileup(config)
    data = {"reads": reads, "positions": positions, "quality": quality}
    result, _, _ = pileup_op.apply(data, {}, None)
    pileup = result['pileup']
    # Entropy: -sum(p * log(p))
    entropy = -(pileup * jnp.log(pileup + 1e-8)).sum(axis=-1)
    return entropy.mean()
# Gradient of entropy w.r.t. quality scores
grad_fn = jax.grad(pileup_entropy, argnums=2)
quality_grads = grad_fn(reads, positions, quality, config)
# Which bases' quality scores most affect uncertainty?
most_influential = jnp.unravel_index(
    jnp.argmax(jnp.abs(quality_grads)),
    quality_grads.shape
)
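The same idea can be checked without DiffBio installed. The sketch below is a simplified stand-in, assuming a quality-weighted pileup normalized by coverage and no softmax; `soft_pileup` and `mean_entropy` are hypothetical helpers, not the library's implementation.

```python
import jax
import jax.numpy as jnp

def soft_pileup(reads, positions, quality, reference_length):
    """Quality-weighted pileup normalized by coverage (illustrative only)."""
    read_length = reads.shape[1]
    weights = 1.0 - jnp.power(10.0, -quality / 10.0)
    absolute = (positions[:, None] + jnp.arange(read_length)).reshape(-1)
    sums = jax.ops.segment_sum(
        reads.reshape(-1, 4) * weights.reshape(-1, 1),
        absolute,
        num_segments=reference_length,
    )
    coverage = jax.ops.segment_sum(
        weights.reshape(-1), absolute, num_segments=reference_length
    )
    return sums / jnp.maximum(coverage[:, None], 1e-8)

def mean_entropy(quality, reads, positions, reference_length):
    """Mean per-position entropy of the pileup."""
    pileup = soft_pileup(reads, positions, quality, reference_length)
    return -(pileup * jnp.log(pileup + 1e-8)).sum(axis=-1).mean()

# Two single-base reads disagree at position 0 (A at Q30 vs. C at Q10).
reads = jnp.eye(4)[jnp.array([[0], [1]])]  # (2, 1, 4)
positions = jnp.array([0, 0])
quality = jnp.array([[30.0], [10.0]])

# Differentiate with respect to the first argument, the quality scores.
grads = jax.grad(mean_entropy)(quality, reads, positions, 1)
print(grads)
```

In this toy case the gradient is negative for the majority read and positive for the minority read: raising the quality of the stronger call makes the column more decisive, while raising the weaker call pushes it toward a tie.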
Integration with Variant Calling¤
from diffbio.operators.variant import VariantClassifier, VariantClassifierConfig
# Pileup generation
pileup_config = PileupConfig(reference_length=1000)
pileup_op = DifferentiablePileup(pileup_config)
# Variant classification
classifier_config = VariantClassifierConfig(hidden_dims=[64, 32])
classifier = VariantClassifier(classifier_config)
def variant_calling_pipeline(reads, positions, quality):
    # Generate pileup
    data = {"reads": reads, "positions": positions, "quality": quality}
    pileup_result, _, _ = pileup_op.apply(data, {}, None)
    # Classify variants
    pileup = pileup_result['pileup']
    variant_probs = classifier(pileup)
    return variant_probs # (reference_length, 3) for ref/het/hom
# End-to-end gradient computation
# cross_entropy is a user-supplied loss function (not defined here);
# jax.grad differentiates w.r.t. the first argument, the reads
loss_fn = lambda r, p, q, targets: cross_entropy(
    variant_calling_pipeline(r, p, q), targets
)
grads = jax.grad(loss_fn)(reads, positions, quality, true_variants)
Implementation Details¤
Aggregation Algorithm¤
The pileup uses JAX's segment_sum for efficient aggregation:
# For each base position in each read, compute absolute reference position
absolute_positions = positions[:, None] + jnp.arange(read_length)
# Flatten and aggregate
flat_positions = absolute_positions.reshape(-1)
flat_reads = reads.reshape(-1, 4) * weights.reshape(-1, 1)
# Aggregate at each reference position
pileup = jax.ops.segment_sum(
    flat_reads,
    flat_positions,
    num_segments=reference_length
)
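To see the aggregation pattern run end to end, here is a small self-contained sketch on two toy reads. Uniform weights are an assumption made for brevity; the data and variable names are illustrative.

```python
import jax
import jax.numpy as jnp

reference_length = 6
read_length = 3

# Two hard one-hot reads: "ACG" starting at 0 and "CGT" starting at 1.
reads = jnp.eye(4)[jnp.array([[0, 1, 2], [1, 2, 3]])]  # (2, 3, 4)
positions = jnp.array([0, 1])                          # (2,)
weights = jnp.ones((2, read_length))                   # uniform weights (assumption)

# Absolute reference position of every base in every read.
absolute_positions = positions[:, None] + jnp.arange(read_length)  # (2, 3)

# Flatten reads and positions, then sum weighted bases per reference position.
flat_positions = absolute_positions.reshape(-1)
flat_reads = reads.reshape(-1, 4) * weights.reshape(-1, 1)
counts = jax.ops.segment_sum(
    flat_reads, flat_positions, num_segments=reference_length
)
print(counts)
```

Positions 1 and 2 are covered by both reads and accumulate a value of 2 in the C and G channels respectively, while positions 4 and 5 receive no bases and stay at zero.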
Normalization¤
After aggregation, the pileup is normalized to probability distributions:
# Normalize by per-position coverage (aggregated with the same flattened indices)
coverage = jax.ops.segment_sum(
    weights.reshape(-1), flat_positions, num_segments=reference_length
)
pileup_normalized = pileup / jnp.maximum(coverage[:, None], 1e-8)
# Apply softmax for valid distribution
pileup_final = jax.nn.softmax(pileup_normalized / temperature, axis=-1)
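The two normalization steps can be tried on toy numbers. This sketch assumes the reads are proper distributions, so coverage equals the row sum of the aggregated counts; the temperature value is illustrative.

```python
import jax
import jax.numpy as jnp

# Unnormalized weighted counts for three positions (toy values).
counts = jnp.array([
    [3.0, 1.0, 0.0, 0.0],  # mostly A
    [0.0, 0.0, 2.0, 2.0],  # G/T tie
    [0.0, 0.0, 0.0, 0.0],  # no coverage
])
coverage = counts.sum(axis=-1, keepdims=True)  # (3, 1)

# Divide by coverage; the floor avoids 0/0 at uncovered positions.
frequencies = counts / jnp.maximum(coverage, 1e-8)

# Softmax turns each row into a strictly positive distribution.
temperature = 1.0  # illustrative value
pileup = jax.nn.softmax(frequencies / temperature, axis=-1)
print(pileup)
```

Two behaviours are worth noting. An uncovered position becomes a uniform distribution (0.25 per base), and the softmax flattens the frequencies, so a lower temperature is needed if sharper distributions are desired.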
Out-of-Bounds Handling¤
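Reads extending beyond the reference boundaries are handled automatically: bases that fall outside the valid range receive zero weight.
# Mask per-base positions outside the valid range
# (absolute_positions has shape (num_reads, read_length))
in_bounds = (absolute_positions >= 0) & (absolute_positions < reference_length)
weights = weights * in_bounds.astype(jnp.float32)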
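The sketch below shows the masking idea on a read that overhangs the end of a short reference. It is an illustration under stated assumptions: the per-base positions are computed from the read start, and clipping the indices before aggregation is an extra safeguard added here, not something the source documents.

```python
import jax
import jax.numpy as jnp

reference_length = 5
read_length = 4

reads = jnp.eye(4)[jnp.array([[0, 1, 2, 3]])]  # one read "ACGT", shape (1, 4, 4)
positions = jnp.array([3])                     # bases land on 3, 4, 5, 6
weights = jnp.ones((1, read_length))

absolute_positions = positions[:, None] + jnp.arange(read_length)  # (1, 4)

# Zero the weight of bases that fall outside [0, reference_length).
in_bounds = (absolute_positions >= 0) & (absolute_positions < reference_length)
weights = weights * in_bounds.astype(jnp.float32)

# Clip indices so every segment id is valid; masked bases contribute nothing.
safe_positions = jnp.clip(absolute_positions, 0, reference_length - 1)
counts = jax.ops.segment_sum(
    reads.reshape(-1, 4) * weights.reshape(-1, 1),
    safe_positions.reshape(-1),
    num_segments=reference_length,
)
print(counts)
```

Only the first two bases of the read (A at position 3, C at position 4) reach the pileup; the overhanging G and T are dropped.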
Performance Considerations¤
Memory¤
| Component | Memory |
|---|---|
| Input reads | O(num_reads × read_length × 4) |
| Intermediate | O(num_reads × read_length) |
| Output pileup | O(reference_length × 4) |
For large datasets, process in chunks:
def chunked_pileup(reads_list, positions_list, quality_list, config, chunk_size=1000):
    # Build the operator once from the supplied config
    pileup_op = DifferentiablePileup(config)
    pileups = []
    for i in range(0, len(reads_list), chunk_size):
        chunk_reads = jnp.stack(reads_list[i:i+chunk_size])
        chunk_pos = jnp.stack(positions_list[i:i+chunk_size])
        chunk_qual = jnp.stack(quality_list[i:i+chunk_size])
        data = {"reads": chunk_reads, "positions": chunk_pos, "quality": chunk_qual}
        result, _, _ = pileup_op.apply(data, {}, None)
        pileups.append(result['pileup'])
    # Combine chunks with an unweighted average (approximates the full pileup)
    return jnp.mean(jnp.stack(pileups), axis=0)
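Averaging normalized per-chunk pileups gives every chunk equal weight regardless of how many reads it contributes at a position. If that matters, one alternative is to accumulate unnormalized sums across chunks and normalize once at the end. The sketch below illustrates the difference in plain JAX; `accumulate_chunk` is a hypothetical helper, not a DiffBio function, and quality weighting is omitted for brevity.

```python
import jax
import jax.numpy as jnp

def accumulate_chunk(reads, positions, reference_length):
    """Unnormalized base sums for one chunk (hypothetical helper)."""
    read_length = reads.shape[1]
    absolute = positions[:, None] + jnp.arange(read_length)
    return jax.ops.segment_sum(
        reads.reshape(-1, 4), absolute.reshape(-1), num_segments=reference_length
    )

def normalize(sums):
    """Divide each position by its coverage, guarding against zero coverage."""
    return sums / jnp.maximum(sums.sum(axis=-1, keepdims=True), 1e-8)

reference_length = 4
# Chunk 1: three single-base reads calling A at position 0.
chunk1 = (jnp.eye(4)[jnp.zeros((3, 1), dtype=jnp.int32)], jnp.zeros(3, dtype=jnp.int32))
# Chunk 2: one single-base read calling C at position 0.
chunk2 = (jnp.eye(4)[jnp.ones((1, 1), dtype=jnp.int32)], jnp.zeros(1, dtype=jnp.int32))

per_chunk = [accumulate_chunk(r, p, reference_length) for r, p in (chunk1, chunk2)]

# Normalize once over the accumulated sums.
exact = normalize(sum(per_chunk))

# Normalize each chunk first, then average the results.
averaged = jnp.mean(jnp.stack([normalize(c) for c in per_chunk]), axis=0)
print(exact[0], averaged[0])
```

At position 0 the accumulated version reports A at 0.75 and C at 0.25, matching the 3:1 read ratio, whereas the chunk average reports 0.5 for each.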
GPU Acceleration¤
# JIT compile for GPU
@jax.jit
def fast_pileup(reads, positions, quality):
    data = {"reads": reads, "positions": positions, "quality": quality}
    result, _, _ = pileup_op.apply(data, {}, None)
    return result['pileup']
# First call compiles, subsequent calls are fast
pileup = fast_pileup(reads, positions, quality)
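One detail to keep in mind when jitting custom aggregation code: `num_segments` in `jax.ops.segment_sum` must be known at trace time, so a reference length passed as a function argument has to be marked static. The sketch below shows this in plain JAX; `count_pileup` is a hypothetical helper written for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="reference_length")
def count_pileup(reads, positions, reference_length):
    """Unweighted base sums per reference position (illustrative only)."""
    read_length = reads.shape[1]
    absolute = (positions[:, None] + jnp.arange(read_length)).reshape(-1)
    return jax.ops.segment_sum(
        reads.reshape(-1, 4), absolute, num_segments=reference_length
    )

reads = jnp.eye(4)[jnp.array([[0, 1], [1, 2]])]  # "AC" at 0 and "CG" at 1
positions = jnp.array([0, 1])
counts = count_pileup(reads, positions, reference_length=3)
print(counts)
```

Each distinct `reference_length` value triggers a separate compilation, so keeping it fixed across calls avoids repeated tracing.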
DeepVariant-Style Pileup Images¤
For CNN-based variant calling, DiffBio provides DeepVariantStylePileup that generates multi-channel pileup images compatible with DeepVariant's architecture.
Overview¤
DeepVariant uses pileup "images" where each aligned read is a row and each column is a genomic position. Multiple channels encode different features:
| Channel | Description | Values |
|---|---|---|
| Base (A/C/G/T) | 4 one-hot channels for nucleotide identity | 0 or 1 |
| Base Quality | Phred score normalized to [0,1] | [0, 1] |
| Mapping Quality | MAPQ normalized to [0,1] | [0, 1] |
| Strand | Read orientation | 0=forward, 1=reverse |
| Supports Variant | Soft mismatch indicator | [0, 1] |
| Differs from Ref | Reference mismatch | [0, 1] |
Quick Start¤
import jax
import jax.numpy as jnp
from diffbio.operators.variant import DeepVariantStylePileup, DeepVariantPileupConfig
# Configure pileup
config = DeepVariantPileupConfig(
    window_size=221, # Standard DeepVariant window
    max_reads=100, # Maximum reads per pileup
)
# Create operator
pileup_op = DeepVariantStylePileup(config)
# Prepare data
num_reads = 30
read_length = 50
window_size = 221
reads = jax.nn.softmax(
    jax.random.uniform(jax.random.PRNGKey(0), (num_reads, read_length, 4)),
    axis=-1
)
reference = jax.nn.softmax(
    jax.random.uniform(jax.random.PRNGKey(1), (window_size, 4)),
    axis=-1
)
base_qualities = jax.random.uniform(jax.random.PRNGKey(2), (num_reads, read_length)) * 40
mapping_qualities = jax.random.uniform(jax.random.PRNGKey(3), (num_reads,)) * 60
strands = (jax.random.uniform(jax.random.PRNGKey(4), (num_reads,)) > 0.5).astype(jnp.float32)
positions = jax.random.randint(jax.random.PRNGKey(5), (num_reads,), 0, window_size - read_length)
# Generate pileup image
data = {
    "reads": reads,
    "reference": reference,
    "base_qualities": base_qualities,
    "mapping_qualities": mapping_qualities,
    "strands": strands,
    "positions": positions,
}
result, _, _ = pileup_op.apply(data, {}, None)
print(f"Pileup image shape: {result['pileup_image'].shape}") # (100, 221, 9)
Configuration¤
DeepVariantPileupConfig¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| window_size | int | 221 | Width of pileup in base pairs |
| max_reads | int | 100 | Maximum reads (image height) |
| channels | tuple[str, ...] | ("base", "base_quality", "mapping_quality", "strand", "supports_variant", "differs_from_ref") | Ordered DeepVariant channels to emit |
| quality_max | float | 40.0 | Max quality for normalization |
| mapq_max | float | 60.0 | Max MAPQ for normalization |
| temperature | float | 1.0 | Temperature for soft operations |
Channel Details¤
Base Identity Channels (4 channels)¤
Each base position is encoded as a one-hot vector over {A, C, G, T}:
# Example: Read with sequence "ACG"
# Position 0: [1, 0, 0, 0] # A
# Position 1: [0, 1, 0, 0] # C
# Position 2: [0, 0, 1, 0] # G
Quality Channels¤
Quality scores are normalized to [0, 1]:
# Base quality: normalized by quality_max (default 40)
normalized_bq = base_quality / 40.0
# Mapping quality: normalized by mapq_max (default 60)
normalized_mapq = mapping_quality / 60.0
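A small runnable version of this normalization is sketched below. Clipping values above the maximum is an assumption added here so the result stays in [0, 1]; the source only shows the division. `normalize_quality` is a hypothetical helper name.

```python
import jax.numpy as jnp

def normalize_quality(values, max_value):
    """Scale scores to [0, 1], clipping anything above max_value."""
    return jnp.clip(values / max_value, 0.0, 1.0)

base_quality = jnp.array([0.0, 20.0, 40.0, 45.0])
mapping_quality = jnp.array([0.0, 30.0, 60.0])

normalized_bq = normalize_quality(base_quality, 40.0)       # quality_max default
normalized_mapq = normalize_quality(mapping_quality, 60.0)  # mapq_max default
print(normalized_bq, normalized_mapq)
```

Note that clipping has zero gradient for values beyond the maximum, which only matters if the quality scores themselves are being optimized.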
Variant Support Channel¤
Uses soft comparison for differentiability:
# Soft mismatch: 1 - dot product of one-hot vectors
match_score = jnp.sum(read_base * ref_base)
mismatch = 1.0 - match_score
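The same soft comparison can be applied to a whole stack of reads at once. The sketch below is an illustration in plain JAX, assuming reads are already aligned to the reference window; it is not taken from the DiffBio source.

```python
import jax
import jax.numpy as jnp

# Reference window "ACGT" as one-hot, shape (4, 4).
reference = jnp.eye(4)[jnp.array([0, 1, 2, 3])]

# Two aligned reads over the same window, shape (2, 4, 4).
reads = jnp.stack([
    jnp.eye(4)[jnp.array([0, 1, 2, 3])],  # matches the reference everywhere
    jnp.eye(4)[jnp.array([0, 1, 3, 3])],  # T instead of G at position 2
])

# Dot product over the base axis gives a match score per read and position.
match_score = jnp.einsum("rpb,pb->rp", reads, reference)
mismatch = 1.0 - match_score  # (2, 4)

# With a soft base call the mismatch value is fractional rather than 0 or 1.
soft_base = jax.nn.softmax(jnp.array([2.0, 0.0, 0.0, 0.0]))
soft_mismatch = 1.0 - jnp.dot(soft_base, reference[0])
print(mismatch, soft_mismatch)
```

For hard one-hot inputs the channel reduces to an exact 0/1 mismatch indicator; for soft inputs it equals the probability mass placed on non-reference bases.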
Integration with CNN Classifiers¤
from flax import nnx
from diffbio.operators.variant import (
    DeepVariantStylePileup,
    DeepVariantPileupConfig,
    CNNVariantClassifier,
    CNNVariantClassifierConfig,
)
# Create pileup generator
pileup_config = DeepVariantPileupConfig(window_size=101, max_reads=50)
pileup_op = DeepVariantStylePileup(pileup_config)
# Create CNN classifier
classifier_config = CNNVariantClassifierConfig(
    num_classes=3, # ref/het/hom_alt
    window_size=101,
    num_channels=pileup_op.num_channels,
)
classifier = CNNVariantClassifier(classifier_config, rngs=nnx.Rngs(42))
# End-to-end pipeline
def variant_pipeline(reads, reference, base_qualities, mapping_qualities, strands, positions):
    # Generate pileup image
    pileup_data = {
        "reads": reads,
        "reference": reference,
        "base_qualities": base_qualities,
        "mapping_qualities": mapping_qualities,
        "strands": strands,
        "positions": positions,
    }
    pileup_result, _, _ = pileup_op.apply(pileup_data, {}, None)
    # Classify variants
    # Reshape for batch processing: (1, height, width, channels)
    pileup_batch = pileup_result["pileup_image"][None, ...]
    classifier_data = {"pileup_tensor": pileup_batch}
    result, _, _ = classifier.apply(classifier_data, {}, None)
    return result["predictions"] # (1, 3)
Differentiability¤
The entire pileup generation is differentiable, enabling end-to-end training:
import jax
def loss_fn(pileup_op, data, targets):
    result, _, _ = pileup_op.apply(data, {}, None)
    pileup_image = result["pileup_image"]
    # Example: minimize difference from target
    return jnp.mean((pileup_image - targets) ** 2)
# Compute gradients
grads = jax.grad(loss_fn)(pileup_op, data, target_image)
Performance Tips¤
- JIT Compilation: Always JIT compile for production use:
@jax.jit
def fast_pileup(data):
    result, _, _ = pileup_op.apply(data, {}, None)
    return result["pileup_image"]
- Channel Selection: Disable unused channels to reduce memory:
# Minimal configuration for base-only analysis
config = DeepVariantPileupConfig(
    channels=("base",),
)
# Only 4 channels instead of 9
- Window Size: Match window size to your variant calling context (221 for SNPs, larger for indels).
References¤
- Li, H. et al. (2009). "The Sequence Alignment/Map format and SAMtools."
- Poplin, R. et al. (2018). "A universal SNP and small-indel variant caller using deep neural networks."
- Google DeepVariant (2020). "Looking through DeepVariant's eyes." https://google.github.io/deepvariant/posts/2020-02-20-looking-through-deepvariants-eyes/