Pileup API
Differentiable pileup generation for variant calling.
DifferentiablePileup
diffbio.operators.variant.pileup.DifferentiablePileup
```python
DifferentiablePileup(
    config: PileupConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
```
Bases: TemperatureOperator
Differentiable pileup generator.
Aggregates aligned reads into a position-wise nucleotide distribution that can be used for variant calling. Unlike a traditional pileup, which simply counts bases, this implementation uses soft weighting so that gradients can flow through it.
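As a rough illustration of soft weighting, the sketch below scatters quality-weighted base probabilities into reference coordinates with `jax.ops.segment_sum`. The function name `soft_pileup_sketch` and the Phred-to-weight mapping (weight = 1 - 10^(-Q/10)) are assumptions for illustration, not DiffBio's implementation. Because the result is a sum of products, gradients with respect to `reads` and `quality` are well defined.

```python
import jax
import jax.numpy as jnp


def soft_pileup_sketch(reads, positions, quality, reference_length):
    """Hypothetical quality-weighted pileup (not the DiffBio implementation).

    reads:     (num_reads, read_length, 4) one-hot or soft base probabilities
    positions: (num_reads,) integer start position of each read
    quality:   (num_reads, read_length) Phred quality scores
    """
    num_reads, read_length, _ = reads.shape
    # Reference coordinate covered by every base of every read
    ref_pos = positions[:, None] + jnp.arange(read_length)[None, :]
    # Assumed weighting: probability that the base call is correct
    weights = 1.0 - 10.0 ** (-quality / 10.0)
    weighted = reads * weights[..., None]
    # Scatter-add each base into its reference position; segment ids outside
    # [0, reference_length) are dropped by segment_sum
    return jax.ops.segment_sum(
        weighted.reshape(-1, 4),
        ref_pos.reshape(-1),
        num_segments=reference_length,
    )
```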
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
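The two smoothing helpers listed above are commonly defined as follows. This is a minimal sketch of the standard formulations, assuming a scalar `temperature`; the actual `TemperatureOperator` methods may differ in signature and details. As the temperature approaches zero, `soft_max` approaches the hard maximum and `soft_argmax` approaches the index of the maximum.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_max(x, temperature):
    """Smooth maximum: temperature * logsumexp(x / temperature) over the last axis."""
    return temperature * logsumexp(x / temperature, axis=-1)


def soft_argmax(x, temperature):
    """Soft position selection: expected index under softmax(x / temperature)."""
    probs = jax.nn.softmax(x / temperature, axis=-1)
    return jnp.sum(probs * jnp.arange(x.shape[-1]), axis=-1)


scores = jnp.array([1.0, 5.0, 2.0])
print(soft_max(scores, 0.01))     # close to the hard max, 5.0
print(soft_argmax(scores, 0.01))  # close to the hard argmax, 1
```

Higher temperatures spread the weight over more entries, which gives smoother gradients at the cost of a less exact maximum.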
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | PileupConfig | Pileup configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
compute_pileup
```python
compute_pileup(
    reads: Float[Array, "num_reads read_length 4"],
    positions: Int[Array, "num_reads"],
    quality: Float[Array, "num_reads read_length"],
    reference_length: int,
) -> dict[str, Float[Array, ...]]
```
Generate pileup from aligned reads.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reads | Float[Array, 'num_reads read_length 4'] | One-hot encoded reads (num_reads, read_length, 4). | required |
| positions | Int[Array, 'num_reads'] | Starting position of each read (num_reads,). | required |
| quality | Float[Array, 'num_reads read_length'] | Quality scores for each base (num_reads, read_length). | required |
| reference_length | int | Length of the reference sequence. | required |
Returns:

| Type | Description |
|---|---|
| dict[str, Float[Array, ...]] | Dictionary of output arrays, including "pileup", the nucleotide distribution at each position with shape (reference_length, 4). |
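`PileupConfig` exposes `return_coverage` and `return_quality` flags for optional per-position channels. The sketch below shows one plausible way to compute such channels, read depth and mean base quality per reference position; the function name and the definitions of both channels are assumptions, not the operator's documented outputs.

```python
import jax
import jax.numpy as jnp


def coverage_and_mean_quality(positions, quality, reference_length):
    """Hypothetical per-position read depth and mean base quality."""
    read_length = quality.shape[1]
    # Reference coordinate of every base, flattened to (num_reads * read_length,)
    ref_pos = (positions[:, None] + jnp.arange(read_length)[None, :]).reshape(-1)
    # Depth: number of bases that land on each reference position
    coverage = jax.ops.segment_sum(
        jnp.ones(ref_pos.shape, dtype=jnp.float32), ref_pos, num_segments=reference_length
    )
    quality_sum = jax.ops.segment_sum(
        quality.reshape(-1), ref_pos, num_segments=reference_length
    )
    # Guard against division by zero; uncovered positions report a mean quality of 0
    mean_quality = quality_sum / jnp.maximum(coverage, 1.0)
    return coverage, mean_quality
```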
apply
```python
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
```
Apply pileup generation to read data.
This method implements the OperatorModule interface for batch processing. It expects `data` to contain the reads, their start positions, and per-base quality scores, and returns the data with a pileup added.
Note: reference_length is taken from the config rather than from the data because it must be static for JAX's segment_sum. All reads in a batch must therefore align to the same reference. The output preserves the input keys for Datarax vmap compatibility.
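The static-shape requirement can be seen in a small standalone example. `jax.ops.segment_sum` uses `num_segments` to fix the output shape, so under `jax.jit` it has to be a concrete Python integer; marking it static, as below, retraces the function for each distinct length. The function `count_bases` is hypothetical and only illustrates the constraint.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="reference_length")
def count_bases(reads, positions, reference_length):
    """Hypothetical unweighted pileup; reference_length must be static."""
    read_length = reads.shape[1]
    ref_pos = (positions[:, None] + jnp.arange(read_length)[None, :]).reshape(-1)
    # num_segments determines the output shape (reference_length, 4),
    # so it cannot be a traced value inside jit
    return jax.ops.segment_sum(
        reads.reshape(-1, 4), ref_pos, num_segments=reference_length
    )


reads = jax.nn.one_hot(jnp.array([[0, 1, 2], [0, 3, 2]]), 4)
positions = jnp.array([0, 1])
print(count_bases(reads, positions, reference_length=5).shape)  # (5, 4)
```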
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing "reads": one-hot encoded reads (num_reads, read_length, 4); "positions": starting position of each read (num_reads,); "quality": quality scores for each base (num_reads, read_length). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used (deterministic operator). | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:

| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata): transformed_data contains the input data plus the "pileup" array; state and metadata are passed through unchanged. |
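The pass-through contract and the note about preserving input keys can be sketched with a plain function that has the same `(data, state, metadata)` shape as `apply`. Keeping every input key and adding "pileup" gives each batch element the same output structure, which is what `jax.vmap` needs. `apply_sketch` and the fixed `REFERENCE_LENGTH` are hypothetical stand-ins, not Datarax or DiffBio APIs, and quality weighting is omitted for brevity.

```python
import jax
import jax.numpy as jnp

REFERENCE_LENGTH = 5  # static, playing the role of PileupConfig.reference_length


def apply_sketch(data, state, metadata):
    """Hypothetical operator: add an unweighted pileup, pass state/metadata through."""
    reads = data["reads"]
    read_length = reads.shape[1]
    ref_pos = (data["positions"][:, None] + jnp.arange(read_length)[None, :]).reshape(-1)
    pileup = jax.ops.segment_sum(
        reads.reshape(-1, 4), ref_pos, num_segments=REFERENCE_LENGTH
    )
    # Preserve all input keys and add the new one
    return {**data, "pileup": pileup}, state, metadata


# Map over a leading batch axis of `data`; state and metadata are not batched
batched_apply = jax.vmap(apply_sketch, in_axes=(0, None, None))

batch = {
    "reads": jax.nn.one_hot(jnp.zeros((3, 2, 3), dtype=jnp.int32), 4),
    "positions": jnp.array([[0, 1], [1, 2], [0, 0]]),
    "quality": jnp.full((3, 2, 3), 30.0),
}
out, state, metadata = batched_apply(batch, {}, None)
print(sorted(out))  # ['pileup', 'positions', 'quality', 'reads']
```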
PileupConfig
diffbio.operators.variant.pileup.PileupConfig (dataclass)
```python
PileupConfig(
    temperature: float = DEFAULT_TEMPERATURE,
    learnable_temperature: bool = False,
    use_quality_weights: bool = True,
    reference_length: int = 100,
    return_coverage: bool = False,
    return_quality: bool = False,
    apply_softmax: bool = True,
)
```
Bases: TemperatureConfig
Configuration for differentiable pileup.
Inherits from TemperatureConfig to get temperature and learnable_temperature fields.
Attributes:

| Name | Type | Description |
|---|---|---|
| use_quality_weights | bool | Whether to weight bases by quality scores. |
| reference_length | int | Length of the reference sequence (required for batch processing). All reads in a batch must align to a reference of this length. |
| return_coverage | bool | Whether to return a coverage channel in the output. |
| return_quality | bool | Whether to return a mean-quality channel in the output. |
| apply_softmax | bool | Whether to apply a softmax to the final pileup. Set to False to keep the raw weighted sums, which are better suited to variant detection. |
Usage Examples
Basic Pileup Generation
```python
import jax
import jax.numpy as jnp

from diffbio.operators.variant import DifferentiablePileup, PileupConfig

# Configure the operator
config = PileupConfig(
    reference_length=100,
    use_quality_weights=True,
)
pileup_op = DifferentiablePileup(config)

# Prepare data: softmax-normalized (soft) reads stand in for one-hot reads
num_reads = 20
read_length = 30
reads = jax.nn.softmax(
    jax.random.uniform(jax.random.PRNGKey(0), (num_reads, read_length, 4)),
    axis=-1,
)
# Start positions in [0, 70) keep every 30-base read inside the 100-base reference
positions = jax.random.randint(jax.random.PRNGKey(1), (num_reads,), 0, 70)
# Quality scores in [10, 40); minval/maxval must be keywords (the third positional arg is dtype)
quality = jax.random.uniform(
    jax.random.PRNGKey(2), (num_reads, read_length), minval=10.0, maxval=40.0
)

# Generate the pileup
result = pileup_op.compute_pileup(reads, positions, quality, config.reference_length)
pileup = result["pileup"]
print(f"Pileup shape: {pileup.shape}")  # (100, 4)
```
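Datarax Interface
```python
# Reuses pileup_op, reads, positions and quality from the example above
data = {
    "reads": reads,
    "positions": positions,
    "quality": quality,
}
# state and metadata are passed through unchanged
result_data, state, metadata = pileup_op.apply(data, {}, None)
pileup = result_data["pileup"]
```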
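Gradient Computation
```python
import jax
import jax.numpy as jnp

# Illustrative target with shape (reference_length, 4)
target_pileup = jnp.zeros((config.reference_length, 4))

def pileup_loss(reads, quality):
    data = {"reads": reads, "positions": positions, "quality": quality}
    result, _, _ = pileup_op.apply(data, {}, None)
    return jnp.mean((result["pileup"] - target_pileup) ** 2)

# Positions are integers, so differentiate with respect to reads and quality only
grad_reads, grad_quality = jax.grad(pileup_loss, argnums=(0, 1))(reads, quality)
```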
Input Specifications
reads
| Property | Value |
|---|---|
| Shape | (num_reads, read_length, 4) |
| Type | Float[Array, ...] |
| Description | One-hot encoded read sequences |
positions
| Property | Value |
|---|---|
| Shape | (num_reads,) |
| Type | Int[Array, ...] |
| Description | Starting position of each read |
quality
| Property | Value |
|---|---|
| Shape | (num_reads, read_length) |
| Type | Float[Array, ...] |
| Description | Phred quality scores |
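Phred scores are on a log scale: a score Q corresponds to an estimated base-call error probability of 10^(-Q/10), so the range 10 to 40 used in the usage example spans error rates from 1 in 10 to 1 in 10,000. The helper below is a small illustration; whether `DifferentiablePileup` uses exactly this conversion for its quality weights is an assumption.

```python
import jax.numpy as jnp


def phred_to_error_prob(q):
    """Convert Phred quality scores to base-call error probabilities."""
    return 10.0 ** (-q / 10.0)


q = jnp.array([10.0, 20.0, 30.0, 40.0])
error = phred_to_error_prob(q)  # approximately [0.1, 0.01, 0.001, 0.0001]
confidence = 1.0 - error        # probability that the base call is correct
```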
Output Specifications
pileup
| Property | Value |
|---|---|
| Shape | (reference_length, 4) |
| Type | Float[Array, ...] |
| Description | Nucleotide distribution at each position |
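A pileup of shape (reference_length, 4) can be reduced to simple per-position summaries. The sketch below computes, for each position, the fraction of weighted bases that disagree with a reference base. The function name, the A/C/G/T index order, and the zero-depth handling are assumptions for illustration and not part of the documented API; since each row is renormalized, it works on raw weighted sums and on softmax-normalized rows alike.

```python
import jax.numpy as jnp


def non_reference_fraction(pileup, reference):
    """Hypothetical per-position fraction of pileup mass on non-reference bases.

    pileup:    (reference_length, 4) nucleotide weights, assumed A/C/G/T order
    reference: (reference_length,) integer reference base indices
    """
    depth = pileup.sum(axis=-1)
    ref_mass = jnp.take_along_axis(pileup, reference[:, None], axis=-1)[:, 0]
    # Positions with no coverage get a fraction of 0
    return jnp.where(depth > 0, 1.0 - ref_mass / jnp.maximum(depth, 1e-8), 0.0)
```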