
Mapping Operators API

Differentiable operators for read mapping using neural networks.

NeuralReadMapper

diffbio.operators.mapping.neural_mapper.NeuralReadMapper

NeuralReadMapper(
    config: NeuralReadMapperConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Neural network-based read mapper.

This operator uses cross-attention between read and reference embeddings to compute soft alignment scores, enabling fully differentiable read mapping.

Algorithm:

1. Encode the read and the reference with positional embeddings.
2. Apply transformer layers with cross-attention.
3. Compute position-wise alignment scores.
4. Apply a softmax to obtain position probabilities.
5. Compute the mapping quality from the confidence of the position distribution.
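
To make steps 3 to 5 concrete, here is a deliberately stripped-down, non-learned analogue that turns cross-attention logits into per-position scores, position probabilities, and a confidence value. It is a sketch under stated assumptions, not DiffBio's implementation: it skips the positional embeddings and transformer layers of steps 1 and 2, uses the one-hot vectors themselves as query and key embeddings, and assumes the score for a candidate start position p is the sum of logits along the diagonal that places read base i at reference position p + i. The helper names `one_hot` and `soft_map` and the A, C, G, T channel order are illustrative assumptions. NumPy is used so the block runs standalone; the core array operations have direct jax.numpy equivalents.

```python
import numpy as np

BASES = "ACGT"  # assumed one-hot channel order (illustrative)


def one_hot(seq: str) -> np.ndarray:
    """Encode an A/C/G/T string as a (len, 4) one-hot matrix."""
    return np.eye(4)[[BASES.index(b) for b in seq]]


def soft_map(read: np.ndarray, reference: np.ndarray, temperature: float = 1.0):
    """Score every start position of a (read_len, 4) read in a (ref_len, 4) reference."""
    read_len, ref_len = read.shape[0], reference.shape[0]
    # Cross-attention logits: read bases are queries, reference bases are keys.
    # With one-hot vectors as embeddings, logits[i, j] is 1 where the bases match.
    logits = read @ reference.T  # (read_len, ref_len)
    # Step 3: the score for start p sums the diagonal entries logits[i, p + i].
    n_starts = ref_len - read_len + 1
    scores = np.array([np.trace(logits[:, p:p + read_len]) for p in range(n_starts)])
    # Step 4: temperature softmax over start positions (max-shifted for stability).
    z = scores / temperature
    probs = np.exp(z - z.max())
    probs /= probs.sum()
    # Step 5: confidence is the probability mass on the most likely start.
    best = int(np.argmax(probs))
    return scores, probs, best, float(probs[best])


# Usage: a 10-base read copied from position 12 of a random 40-base reference.
rng = np.random.default_rng(0)
ref_seq = "".join(rng.choice(list(BASES), size=40))
read_seq = ref_seq[12:22]
scores, probs, best, quality = soft_map(one_hot(read_seq), one_hot(ref_seq))
print(best, round(quality, 3))
```

In the real operator the embeddings, attention projections, and the score aggregation are learned, which is what lets it tolerate mismatches and indels that this exact-match scoring would penalize.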

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection
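
The two inherited helpers are commonly defined as a temperature-scaled logsumexp and a softmax-weighted index. The sketch below assumes those standard forms, soft_max(x) = T · logsumexp(x / T) and soft_argmax(x) = Σ_i i · softmax(x / T)_i; the exact signatures and any extra scaling in TemperatureOperator may differ, and these standalone functions are illustrative, not the library's methods.

```python
import numpy as np


def soft_max(x: np.ndarray, temperature: float = 1.0) -> float:
    """Smooth maximum: temperature * logsumexp(x / temperature)."""
    z = x / temperature
    m = z.max()  # shift by the max so exp() cannot overflow
    return float(temperature * (m + np.log(np.exp(z - m).sum())))


def soft_argmax(x: np.ndarray, temperature: float = 1.0) -> float:
    """Expected index under softmax(x / temperature)."""
    z = x / temperature
    w = np.exp(z - z.max())
    w /= w.sum()
    return float((w * np.arange(x.shape[0])).sum())


# Usage: low temperatures approach the hard max/argmax, high ones smooth them out.
x = np.array([0.0, 1.0, 5.0, 1.0])
for t in (0.01, 1.0, 100.0):
    print(t, soft_max(x, t), soft_argmax(x, t))
```

As T approaches 0, soft_max tends to max(x) and soft_argmax to the index of the unique maximum; larger T trades accuracy for smoother gradients. In JAX, jax.scipy.special.logsumexp and jax.nn.softmax provide numerically stable versions of the same operations.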

Parameters:

  • config (NeuralReadMapperConfig, required): Mapper configuration with the model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators used for initialization.
  • name (str | None, default None): Optional operator name.

Example
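```python
from flax import nnx
from diffbio.operators.mapping import NeuralReadMapper, NeuralReadMapperConfig

config = NeuralReadMapperConfig(embedding_dim=64)
mapper = NeuralReadMapper(config, rngs=nnx.Rngs(42))
# read_onehot: (batch, read_len, 4); ref_onehot: (batch, ref_len, 4)
data = {"read": read_onehot, "reference": ref_onehot}
result, state, meta = mapper.apply(data, {}, None)
```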

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply neural read mapping.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "read": one-hot encoded read, shape (batch, read_len, 4)
      - "reference": one-hot encoded reference, shape (batch, ref_len, 4)
  • state (PyTree, required): Element state, passed through unchanged.
  • metadata (dict[str, Any] | None, required): Element metadata, passed through unchanged.
  • random_params (Any, default None): Not used.
  • stats (dict[str, Any] | None, default None): Not used.
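
The data dictionary expects one-hot encoded sequences with a leading batch axis. A minimal encoder sketch follows; it assumes an A, C, G, T channel order and that ambiguous bases such as N become all-zero rows, neither of which is stated above, so check both against how your model was trained. `encode_batch` is a hypothetical helper, not part of diffbio.

```python
import numpy as np

BASES = "ACGT"  # assumed channel order


def encode_batch(seqs: list[str], length: int) -> np.ndarray:
    """One-hot encode equal-length DNA strings into a (batch, length, 4) array."""
    out = np.zeros((len(seqs), length, 4), dtype=np.float32)
    for b, seq in enumerate(seqs):
        if len(seq) != length:
            raise ValueError(f"sequence {b} has length {len(seq)}, expected {length}")
        for i, base in enumerate(seq.upper()):
            if base in BASES:
                out[b, i, BASES.index(base)] = 1.0
            # Any other symbol (e.g. N) stays an all-zero row (assumption).
    return out


# Usage: the trailing "n" in the second read becomes an all-zero row.
reads = encode_batch(["ACGT", "acgn"], length=4)  # (2, 4, 4)
# data = {"read": reads, "reference": encode_batch([ref_seq], length=len(ref_seq))}
```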

Returns:

  • tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where:
      - transformed_data contains:
          - "read": the original read
          - "reference": the original reference
          - "alignment_scores": scores for each reference position
          - "position_probs": softmax probabilities over reference positions
          - "best_position": the most likely mapping position
          - "mapping_quality": a confidence score for the mapping
      - state is passed through unchanged
      - metadata is passed through unchanged
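
A common downstream step is keeping only confidently mapped reads. The sketch below assumes best_position and mapping_quality each hold one value per read and that mapping_quality lies in [0, 1]; the 0.9 cutoff and the `confident_hits` helper are illustrative, not part of the API. Thresholding is not differentiable, so it suits inference-time filtering; during training, position_probs is the output that carries gradients.

```python
import numpy as np


def confident_hits(best_position, mapping_quality, min_quality: float = 0.9):
    """Return (read_index, position) pairs whose mapping quality clears min_quality."""
    best_position = np.asarray(best_position)      # also accepts JAX arrays
    mapping_quality = np.asarray(mapping_quality)
    keep = np.flatnonzero(mapping_quality >= min_quality)
    return [(int(i), int(best_position[i])) for i in keep]


# Usage with stand-in values; in practice pass
# result["best_position"] and result["mapping_quality"].
hits = confident_hits(np.array([10, 250, 3]), np.array([0.95, 0.40, 0.91]))
print(hits)
```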

NeuralReadMapperConfig

diffbio.operators.mapping.neural_mapper.NeuralReadMapperConfig (dataclass)

NeuralReadMapperConfig(
    temperature: float = 1.0,
    learnable_temperature: bool = False,
    read_length: int = 150,
    reference_window: int = 500,
    embedding_dim: int = 64,
    num_heads: int = 4,
    num_layers: int = 4,
    dropout_rate: float = 0.1,
)

Bases: TemperatureConfig

Configuration for NeuralReadMapper.

Attributes:

  • read_length (int): Expected read length.
  • reference_window (int): Reference window size.
  • embedding_dim (int): Dimension of the sequence embeddings.
  • num_heads (int): Number of attention heads.
  • num_layers (int): Number of transformer layers.
  • dropout_rate (float): Dropout rate for regularization.
  • temperature (float): Temperature for softmax operations.
  • learnable_temperature (bool, default False): Whether the temperature is learned during training instead of held fixed.
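
Some field combinations are worth checking before building a mapper. Standard multi-head attention splits the embedding evenly across heads (head size = embedding_dim / num_heads), and Flax's nnx.MultiHeadAttention raises an error when the feature size is not divisible by num_heads. The sketch below assumes NeuralReadMapper follows that convention and that a read should fit inside the reference window; `MapperConfigSketch` is a hypothetical stand-in mirroring the documented fields and defaults, and `check_config` is illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MapperConfigSketch:
    """Hypothetical stand-in mirroring NeuralReadMapperConfig's fields and defaults."""
    temperature: float = 1.0
    learnable_temperature: bool = False
    read_length: int = 150
    reference_window: int = 500
    embedding_dim: int = 64
    num_heads: int = 4
    num_layers: int = 4
    dropout_rate: float = 0.1


def check_config(cfg) -> list[str]:
    """Return a list of problems; an empty list means the config looks consistent."""
    problems = []
    if cfg.embedding_dim % cfg.num_heads != 0:
        problems.append("embedding_dim must be divisible by num_heads")
    if cfg.reference_window < cfg.read_length:
        problems.append("reference_window should be at least read_length")
    if not 0.0 <= cfg.dropout_rate < 1.0:
        problems.append("dropout_rate must be in [0, 1)")
    if cfg.temperature <= 0:
        problems.append("temperature must be positive")
    return problems


# Usage: the defaults give 64 / 4 = 16 features per head.
print(check_config(MapperConfigSketch()))
print(check_config(MapperConfigSketch(embedding_dim=100, num_heads=8)))
```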

Usage Examples

Neural Read Mapping

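```python
import jax
from flax import nnx
from diffbio.operators.mapping import NeuralReadMapper, NeuralReadMapperConfig

config = NeuralReadMapperConfig(
    read_length=150,
    reference_window=1000,
    embedding_dim=128,
    num_layers=4,
    num_heads=8,
)
mapper = NeuralReadMapper(config, rngs=nnx.Rngs(42))

# Placeholder inputs: random one-hot sequences with the shapes apply() expects.
key_read, key_ref = jax.random.split(jax.random.PRNGKey(0))
read = jax.nn.one_hot(jax.random.randint(key_read, (1, 150), 0, 4), 4)        # (batch, read_len, 4)
reference = jax.nn.one_hot(jax.random.randint(key_ref, (1, 1000), 0, 4), 4)   # (batch, ref_len, 4)

data = {"read": read, "reference": reference}
result, _, _ = mapper.apply(data, {}, None)
position = result["best_position"]      # most likely mapping position
quality = result["mapping_quality"]     # confidence of that mapping
probs = result["position_probs"]        # softmax over reference positions
```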

Batch Read Mapping

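```python
import jax.numpy as jnp

# Process multiple reads at once.
# read1, read2, read3: one-hot reads, each (read_length, 4)
reads = jnp.stack([read1, read2, read3])  # (3, read_length, 4)
# apply() documents "reference" with a batch axis too, so repeat a single
# one-hot reference_seq of shape (ref_length, 4) once per read.
references = jnp.broadcast_to(reference_seq, (reads.shape[0], *reference_seq.shape))

data = {"read": reads, "reference": references}
result, _, _ = mapper.apply(data, {}, None)

# Get mapping positions for all reads
all_positions = result["best_position"]  # (3,)
```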
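
jnp.stack only accepts arrays of identical shape, so reads of different lengths need padding or trimming first. The sketch below right-pads one-hot reads with all-zero rows up to a fixed length such as config.read_length and truncates longer ones. It assumes zero rows are acceptable padding for the mapper; the documented apply() takes no mask input, so the returned mask is only for your own bookkeeping. `pad_reads` is a hypothetical helper.

```python
import numpy as np


def pad_reads(reads: list[np.ndarray], length: int) -> tuple[np.ndarray, np.ndarray]:
    """Stack (len_i, 4) one-hot reads into (n, length, 4) plus an (n, length) validity mask.

    Shorter reads are right-padded with all-zero rows; longer reads are truncated.
    """
    batch = np.zeros((len(reads), length, 4), dtype=np.float32)
    mask = np.zeros((len(reads), length), dtype=bool)
    for b, r in enumerate(reads):
        n = min(r.shape[0], length)
        batch[b, :n] = r[:n]
        mask[b, :n] = True  # True marks real bases, False marks padding
    return batch, mask


# Usage: a 3-base read is padded and a 6-base read is truncated to length 5.
short_read = np.eye(4)[[0, 1, 2]]   # ACG
long_read = np.eye(4)[[3] * 6]      # TTTTTT
batch, mask = pad_reads([short_read, long_read], length=5)
```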