Pileup API¤

Differentiable pileup generation for variant calling.

DifferentiablePileup¤

diffbio.operators.variant.pileup.DifferentiablePileup ¤

DifferentiablePileup(
    config: PileupConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable pileup generator.

Aggregates aligned reads into a position-wise nucleotide distribution that can be used for variant calling. Unlike a traditional pileup, which simply counts bases, this implementation uses soft weighting so that gradients can flow through the aggregation.
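The idea of a soft, differentiable pileup can be sketched in plain JAX. The function below is a hypothetical stand-in, not the library's implementation: the name soft_pileup, the Phred-based weighting, and the use of jax.ops.segment_sum for the aggregation are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def soft_pileup(reads, positions, quality, reference_length):
    """Hypothetical sketch: quality-weighted pileup via segment_sum."""
    num_reads, read_length, _ = reads.shape
    # Reference index covered by each base: read start + offset within the read.
    ref_idx = positions[:, None] + jnp.arange(read_length)[None, :]
    # Assumed weighting: probability that a base call is correct (Phred scale).
    weights = 1.0 - 10.0 ** (-quality / 10.0)
    weighted = reads * weights[..., None]
    # Sum the weighted base vectors that land on the same reference position.
    return jax.ops.segment_sum(
        weighted.reshape(-1, 4),
        ref_idx.reshape(-1),
        num_segments=reference_length,
    )

reads = jax.nn.one_hot(jnp.array([[0, 1, 2], [1, 2, 3]]), 4)  # 2 reads, length 3
positions = jnp.array([0, 1])
quality = jnp.full((2, 3), 30.0)
pileup = soft_pileup(reads, positions, quality, reference_length=5)

# Because the aggregation is a weighted sum, gradients reach the read encodings.
grad = jax.grad(lambda r: soft_pileup(r, positions, quality, 5).sum())(reads)
print(pileup.shape, grad.shape)
```

A hard count would use integer increments and offer no useful gradient; the weighted sum keeps every step differentiable with respect to the read encodings.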

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection
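For intuition, the two smooth operators can be written as standalone functions. This is a minimal sketch under assumptions: the signatures below are illustrative and are not taken from TemperatureOperator, whose actual methods may differ.

```python
import jax
import jax.numpy as jnp

def soft_max(x, temperature):
    # Smooth maximum: temperature * logsumexp(x / temperature).
    # It upper-bounds max(x) and approaches it as temperature goes to 0.
    return temperature * jax.nn.logsumexp(x / temperature)

def soft_argmax(x, temperature):
    # Expected index under softmax(x / temperature).
    probs = jax.nn.softmax(x / temperature)
    return jnp.sum(probs * jnp.arange(x.shape[0]))

scores = jnp.array([1.0, 3.0, 2.0])
smooth = soft_max(scores, temperature=0.1)
index = soft_argmax(scores, temperature=0.1)

# The gradient of the smooth maximum is the softmax distribution itself.
grad = jax.grad(soft_max)(scores, 0.1)
print(smooth, index, grad)
```

Lower temperatures bring both functions closer to their hard counterparts, at the cost of sharper, more concentrated gradients.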

Parameters:

Name Type Description Default
config PileupConfig

Pileup configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None

compute_pileup ¤

compute_pileup(
    reads: Float[Array, "num_reads read_length 4"],
    positions: Int[Array, num_reads],
    quality: Float[Array, "num_reads read_length"],
    reference_length: int,
) -> dict[str, Float[Array, ...]]

Generate pileup from aligned reads.

Parameters:

Name Type Description Default
reads Float[Array, 'num_reads read_length 4']

One-hot encoded reads (num_reads, read_length, 4).

required
positions Int[Array, num_reads]

Starting position of each read (num_reads,).

required
quality Float[Array, 'num_reads read_length']

Quality scores for each base (num_reads, read_length).

required
reference_length int

Length of reference sequence.

required

Returns:

Type Description
dict[str, Float[Array, ...]]

Dictionary containing:

  • pileup: (reference_length, 4) nucleotide distributions
  • coverage: (reference_length, 1) read depth at each position (if return_coverage)
  • mean_quality: (reference_length, 1) mean quality at each position (if return_quality)
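The optional coverage and mean_quality channels can be pictured with the following sketch. It assumes coverage is the number of bases overlapping each position and mean quality is the plain average over those bases; the function name and the handling of uncovered positions are hypothetical and may not match the library.

```python
import jax
import jax.numpy as jnp

def coverage_and_mean_quality(positions, quality, reference_length):
    """Hypothetical sketch of the two optional output channels."""
    num_reads, read_length = quality.shape
    # Flattened reference index for every base of every read.
    ref_idx = (positions[:, None] + jnp.arange(read_length)[None, :]).reshape(-1)
    # Coverage: how many bases fall on each reference position.
    coverage = jax.ops.segment_sum(
        jnp.ones_like(quality).reshape(-1), ref_idx, num_segments=reference_length
    )
    quality_sum = jax.ops.segment_sum(
        quality.reshape(-1), ref_idx, num_segments=reference_length
    )
    # Avoid dividing by zero where no read covers a position (assumed: report 0).
    mean_quality = quality_sum / jnp.maximum(coverage, 1.0)
    # Trailing axis of size 1 matches the documented (reference_length, 1) shape.
    return coverage[:, None], mean_quality[:, None]

positions = jnp.array([0, 1])
quality = jnp.array([[10.0, 20.0, 30.0], [30.0, 30.0, 30.0]])
coverage, mean_quality = coverage_and_mean_quality(positions, quality, 5)
print(coverage[:, 0], mean_quality[:, 0])
```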

apply ¤

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply pileup generation to read data.

This method implements the OperatorModule interface for batch processing. It expects data containing reads and their positions, and returns pileup.

Note: reference_length is taken from config (not data) because it must be static for JAX's segment_sum. All reads in a batch must align to the same reference. Output preserves input keys for Datarax vmap compatibility.
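The static-shape requirement can be seen in a small standalone example. The num_segments argument of jax.ops.segment_sum fixes the output shape, so JAX needs its value at trace time; under jax.jit this means marking it as a static argument. The helper count_bases is a hypothetical name used only for this sketch.

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames="reference_length")
def count_bases(ref_idx, reference_length):
    # num_segments determines the output shape, so it must be a Python int
    # known at trace time rather than a traced array value.
    return jax.ops.segment_sum(
        jnp.ones_like(ref_idx, dtype=jnp.float32),
        ref_idx,
        num_segments=reference_length,
    )

counts = count_bases(jnp.array([0, 1, 1, 3]), reference_length=5)
print(counts)
```

Keeping reference_length in the configuration serves the same purpose: the value is an ordinary Python integer fixed before tracing, rather than something read from the batch.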

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

  • "reads": One-hot encoded reads (num_reads, read_length, 4)
  • "positions": Starting position of each read (num_reads,)
  • "quality": Quality scores for each base (num_reads, read_length)

required
state PyTree

Element state (passed through unchanged)

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged)

required
random_params Any

Not used (deterministic operator)

None
stats dict[str, Any] | None

Not used

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

  • transformed_data contains the input data plus the pileup array
  • state is passed through unchanged
  • metadata is passed through unchanged
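The return contract described above can be mirrored by a small stand-in function. Everything here is hypothetical (apply_like, dummy_pileup, and the placeholder data); it only illustrates that input keys are preserved, a "pileup" key is added, and state and metadata come back unchanged.

```python
import jax.numpy as jnp

def apply_like(data, state, metadata, pileup_fn):
    """Hypothetical stand-in mirroring the documented apply() contract."""
    pileup = pileup_fn(data["reads"], data["positions"], data["quality"])
    # Keep every input key and add the new "pileup" entry alongside them.
    transformed = {**data, "pileup": pileup}
    # State and metadata are returned untouched.
    return transformed, state, metadata

# Placeholder pileup function: zeros of shape (reference_length, 4).
dummy_pileup = lambda reads, positions, quality: jnp.zeros((5, 4))

data = {
    "reads": jnp.zeros((2, 3, 4)),
    "positions": jnp.array([0, 1]),
    "quality": jnp.full((2, 3), 30.0),
}
out, state, metadata = apply_like(data, {"step": 0}, None, dummy_pileup)
print(sorted(out))
```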

PileupConfig¤

diffbio.operators.variant.pileup.PileupConfig dataclass ¤

PileupConfig(
    temperature: float = DEFAULT_TEMPERATURE,
    learnable_temperature: bool = False,
    use_quality_weights: bool = True,
    reference_length: int = 100,
    return_coverage: bool = False,
    return_quality: bool = False,
    apply_softmax: bool = True,
)

Bases: TemperatureConfig

Configuration for differentiable pileup.

Inherits from TemperatureConfig to get temperature and learnable_temperature fields.

Attributes:

Name Type Description
use_quality_weights bool

Whether to weight bases by quality scores.

reference_length int

Length of reference sequence (required for batch processing). All reads in a batch must align to the same reference length.

return_coverage bool

Whether to return coverage channel in output.

return_quality bool

Whether to return mean quality channel in output.

apply_softmax bool

Whether to apply softmax to final pileup (set False to preserve raw weighted sums, which is better for variant detection).
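The effect of apply_softmax can be illustrated on hand-written weighted sums. This sketch assumes the softmax is taken over the four nucleotide channels of the raw sums divided by the temperature; the exact formula used by the operator is not stated here, so treat the numbers as indicative only.

```python
import jax
import jax.numpy as jnp

# Raw weighted sums per position (columns: four nucleotide channels).
raw = jnp.array([
    [9.0, 1.0, 0.0, 0.0],  # nine reads support one base, one read another
    [5.0, 5.0, 0.0, 0.0],  # evenly split support between two bases
])
temperature = 1.0

# Assumed softmax form: normalizes each row but compresses minor alleles.
probs = jax.nn.softmax(raw / temperature, axis=-1)

# Raw sums retain depth and allele fraction, recoverable by normalizing.
fractions = raw / jnp.maximum(raw.sum(axis=-1, keepdims=True), 1e-8)

print(probs[0, 1], fractions[0, 1])
```

In the first row the minority base holds 10% of the raw weight but well under 1% of the softmax mass, which is one reason raw sums can be preferable when low-frequency variants matter.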

temperature class-attribute instance-attribute ¤

temperature: float = DEFAULT_TEMPERATURE

learnable_temperature class-attribute instance-attribute ¤

learnable_temperature: bool = False

Usage Examples¤

Basic Pileup Generation¤

import jax
import jax.numpy as jnp
from diffbio.operators.variant import DifferentiablePileup, PileupConfig

# Configure
config = PileupConfig(
    reference_length=100,
    use_quality_weights=True,
)
pileup_op = DifferentiablePileup(config)

# Prepare data
num_reads = 20
read_length = 30

reads = jax.nn.softmax(
    jax.random.uniform(jax.random.PRNGKey(0), (num_reads, read_length, 4)),
    axis=-1
)
positions = jax.random.randint(jax.random.PRNGKey(1), (num_reads,), 0, 70)
quality = jax.random.uniform(
    jax.random.PRNGKey(2), (num_reads, read_length), minval=10.0, maxval=40.0
)

# Generate pileup
result = pileup_op.compute_pileup(reads, positions, quality, 100)
pileup = result["pileup"]
print(f"Pileup shape: {pileup.shape}")  # (100, 4)

Datarax Interface¤

data = {
    "reads": reads,
    "positions": positions,
    "quality": quality,
}

result_data, state, metadata = pileup_op.apply(data, {}, None)
pileup = result_data["pileup"]

Gradient Computation¤

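import jax
import jax.numpy as jnp

def pileup_loss(pileup_op, reads, positions, quality, target_pileup):
    # Mean squared error between the generated pileup and the target.
    data = {"reads": reads, "positions": positions, "quality": quality}
    result, _, _ = pileup_op.apply(data, {}, None)
    return jnp.mean((result["pileup"] - target_pileup) ** 2)

# Placeholder target with the pileup's shape (reference_length, 4).
target = jnp.zeros((100, 4))

# Differentiates with respect to the first argument, pileup_op.
# Depending on the Flax version, flax.nnx.grad may be needed for NNX modules.
grads = jax.grad(pileup_loss)(pileup_op, reads, positions, quality, target)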

Input Specifications¤

reads¤

Property Value
Shape (num_reads, read_length, 4)
Type Float[Array, ...]
Description One-hot encoded read sequences
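A reads array of this shape can be built from sequence strings with jax.nn.one_hot. The channel order A, C, G, T and the helper name encode_reads are assumptions for this sketch; check the order the operator expects before relying on it. All reads must share the same length.

```python
import jax
import jax.numpy as jnp

# Assumed channel order; the documentation above does not specify it.
BASE_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}

def encode_reads(sequences):
    """Hypothetical helper: equal-length DNA strings to (num_reads, read_length, 4)."""
    idx = jnp.array([[BASE_TO_INDEX[base] for base in seq] for seq in sequences])
    return jax.nn.one_hot(idx, 4)

reads = encode_reads(["ACGT", "TTGA"])
print(reads.shape)
```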

positions¤

Property Value
Shape (num_reads,)
Type Int[Array, ...]
Description Starting position of each read

quality¤

Property Value
Shape (num_reads, read_length)
Type Float[Array, ...]
Description Phred quality scores
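A Phred score Q encodes an error probability of 10^(-Q/10), so Q = 30 corresponds to a 1-in-1000 chance of a wrong base call. The sketch below converts scores to per-base accuracies. Whether the operator uses exactly this mapping for its quality weights is an assumption, and phred_to_accuracy is a hypothetical helper.

```python
import jax.numpy as jnp

def phred_to_accuracy(quality):
    # Phred definition: Q = -10 * log10(P_error), so P_error = 10 ** (-Q / 10).
    error_prob = 10.0 ** (-quality / 10.0)
    return 1.0 - error_prob

q = jnp.array([10.0, 20.0, 30.0])
accuracy = phred_to_accuracy(q)
print(accuracy)
```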

Output Specifications¤

pileup¤

Property Value
Shape (reference_length, 4)
Type Float[Array, ...]
Description Nucleotide distribution at each position