Preprocessing Pipeline API

End-to-end differentiable preprocessing pipeline.

PreprocessingPipeline

diffbio.pipelines.preprocessing.PreprocessingPipeline

PreprocessingPipeline(
    config: PreprocessingPipelineConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)

Bases: OperatorModule

End-to-end differentiable preprocessing pipeline.

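This pipeline processes sequencing reads through several preprocessing steps. Based on the configuration options, these include quality filtering, adapter removal, duplicate weighting, and error correction; the last three can each be toggled with an enable_* flag.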

Input data structure
  • reads: Float[Array, "num_reads read_length 4"] - One-hot encoded reads
  • quality: Float[Array, "num_reads read_length"] - Base quality scores

Output data structure (adds)
  • preprocessed_reads: Float[Array, "num_reads read_length 4"] - Processed reads
  • preprocessed_quality: Float[Array, "num_reads read_length"] - Processed quality
  • read_weights: Float[Array, "num_reads"] - Read uniqueness weights
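
The input layout described above can be built from raw sequences along the following lines. This is a minimal NumPy sketch: the helper names `one_hot_encode` and `phred_to_array`, the A/C/G/T channel order, and the Phred+33 quality encoding are assumptions for illustration and are not part of the documented API.

```python
import numpy as np

BASES = "ACGT"  # assumed channel order; confirm against your data source


def one_hot_encode(seq: str) -> np.ndarray:
    """Encode a DNA string as a (read_length, 4) float array.

    Bases outside A/C/G/T (for example 'N') become an all-zero row.
    """
    encoded = np.zeros((len(seq), 4), dtype=np.float32)
    for pos, base in enumerate(seq.upper()):
        idx = BASES.find(base)
        if idx >= 0:
            encoded[pos, idx] = 1.0
    return encoded


def phred_to_array(qual: str, offset: int = 33) -> np.ndarray:
    """Decode a FASTQ quality string (assumed Phred+33) into float scores."""
    return np.array([ord(ch) - offset for ch in qual], dtype=np.float32)


# Two toy reads of equal length, so they stack into fixed-shape arrays.
sequences = ["ACGTN", "TTGCA"]
qualities = ["IIII#", "5555I"]

data = {
    "reads": np.stack([one_hot_encode(s) for s in sequences]),    # (2, 5, 4)
    "quality": np.stack([phred_to_array(q) for q in qualities]),  # (2, 5)
}
```

Reads of different lengths would need padding or trimming to a common `read_length` before stacking; how the pipeline treats padded positions is not stated in this reference.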

The pipeline is fully differentiable, supporting gradient-based training to optimize all preprocessing components jointly.

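Example
from flax import nnx
from diffbio.pipelines import PreprocessingPipeline, PreprocessingPipelineConfig

config = PreprocessingPipelineConfig(read_length=150)
pipeline = PreprocessingPipeline(config, rngs=nnx.Rngs(42))
# data is a dict holding the "reads" and "quality" arrays described above
result, state, meta = pipeline.apply(data, {}, None)
processed = result["preprocessed_reads"]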

Parameters:
  • config: PreprocessingPipelineConfig - Pipeline configuration. Required.
  • rngs: Rngs - Random number generators for parameter initialization. Required.
  • name: str | None - Optional name for the pipeline. Default: None.

apply

apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]

Apply the full preprocessing pipeline to reads.

Parameters:
  • data: dict[str, Array] - Input data containing reads (Float[Array, "num_reads read_length 4"]) and quality (Float[Array, "num_reads read_length"]). Required.
  • state: dict[str, Any] - Element state (passed through). Required.
  • metadata: dict[str, Any] | None - Element metadata (passed through). Required.
  • random_params: Any - Not used (deterministic pipeline). Default: None.
  • stats: dict[str, Any] | None - Optional statistics dict. Default: None.

Returns:
  • tuple[dict[str, Array], dict[str, Any], dict[str, Any] | None] - Tuple of (output_data, state, metadata), where output_data contains all input keys plus the preprocessed outputs.
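
Because `apply` expects `reads` and `quality` in fixed layouts, a small check before the call can catch shape mistakes early. The sketch below is a hypothetical helper (`check_pipeline_inputs` is not part of diffbio) that verifies only the shapes documented for the `data` argument.

```python
import numpy as np


def check_pipeline_inputs(data: dict) -> tuple[int, int]:
    """Validate the documented input layout.

    Returns (num_reads, read_length) when the layout is consistent.
    """
    for key in ("reads", "quality"):
        if key not in data:
            raise KeyError(f"missing required key: {key!r}")

    reads_shape = tuple(np.shape(data["reads"]))
    quality_shape = tuple(np.shape(data["quality"]))

    # reads must be (num_reads, read_length, 4): one channel per base.
    if len(reads_shape) != 3 or reads_shape[-1] != 4:
        raise ValueError(
            f"reads must be (num_reads, read_length, 4), got {reads_shape}"
        )
    # quality must match the first two axes of reads.
    if quality_shape != reads_shape[:2]:
        raise ValueError(
            f"quality must be {reads_shape[:2]}, got {quality_shape}"
        )
    return reads_shape[0], reads_shape[1]


num_reads, read_length = check_pipeline_inputs(
    {"reads": np.zeros((3, 7, 4)), "quality": np.zeros((3, 7))}
)
```

The check only reads each array's shape, so it should also work on JAX arrays.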

PreprocessingPipelineConfig

diffbio.pipelines.preprocessing.PreprocessingPipelineConfig dataclass

PreprocessingPipelineConfig(
    read_length: int = 150,
    adapter_sequence: str = "AGATCGGAAGAG",
    quality_threshold: float = 20.0,
    adapter_match_threshold: float = 0.8,
    adapter_temperature: float = 1.0,
    duplicate_similarity_threshold: float = 0.95,
    error_correction_window: int = 11,
    error_correction_hidden_dim: int = 64,
    enable_adapter_removal: bool = True,
    enable_duplicate_weighting: bool = True,
    enable_error_correction: bool = True,
)

Bases: OperatorConfig

Configuration for the preprocessing pipeline.

Attributes:
  • read_length: int - Expected read length for initialization.
  • adapter_sequence: str - Adapter sequence to remove (Illumina universal default).
  • quality_threshold: float - Initial quality score threshold for filtering.
  • adapter_match_threshold: float - Threshold for adapter matching.
  • adapter_temperature: float - Temperature for soft adapter trimming.
  • duplicate_similarity_threshold: float - Similarity threshold for duplicate detection.
  • error_correction_window: int - Window size for error correction.
  • error_correction_hidden_dim: int - Hidden dimension for error correction network.
  • enable_adapter_removal: bool - Whether to enable adapter removal step.
  • enable_duplicate_weighting: bool - Whether to enable duplicate weighting step.
  • enable_error_correction: bool - Whether to enable error correction step.
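
This reference does not spell out how the thresholds and `adapter_temperature` are applied internally. A common way to make a hard cutoff differentiable is a sigmoid relaxation, sketched below with NumPy. The helper `soft_threshold` is hypothetical, and the formula illustrates the general technique rather than diffbio's implementation.

```python
import numpy as np


def soft_threshold(scores, threshold, temperature=1.0):
    """Smooth stand-in for `scores > threshold`; returns values in (0, 1)."""
    z = (np.asarray(scores, dtype=np.float64) - threshold) / temperature
    return 1.0 / (1.0 + np.exp(-z))


quality = np.array([10.0, 19.0, 20.0, 21.0, 30.0])

# A lower temperature gives a sharper transition around the threshold;
# a higher one spreads the transition (and the gradient) over more scores.
smooth = soft_threshold(quality, threshold=20.0, temperature=5.0)
sharp = soft_threshold(quality, threshold=20.0, temperature=0.5)
```

Under this relaxation, a score exactly at the threshold maps to 0.5 regardless of temperature, and the output approaches a hard 0/1 mask as the temperature shrinks.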

Usage Examples

Basic Usage

from flax import nnx
from diffbio.pipelines import PreprocessingPipeline, PreprocessingPipelineConfig

config = PreprocessingPipelineConfig(
    quality_threshold=20.0,
    enable_adapter_removal=True,
    enable_duplicate_weighting=True,
    enable_error_correction=True,
)

pipeline = PreprocessingPipeline(config, rngs=nnx.Rngs(42))

data = {
    "reads": read_sequences,    # (n_reads, read_length, 4)
    "quality": quality_scores,  # (n_reads, read_length)
}
result, _, _ = pipeline.apply(data, {}, None)

preprocessed = result["preprocessed_reads"]
weights = result["read_weights"]
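
One way to use `read_weights` downstream is to down-weight likely duplicates when aggregating across reads. The sketch below computes weighted per-position base frequencies with NumPy on stand-in arrays of the documented shapes. The helper name `weighted_base_frequencies` and the reading that a lower weight marks a less unique read are assumptions for illustration.

```python
import numpy as np


def weighted_base_frequencies(reads, weights, eps=1e-8):
    """Weighted average of reads along the read axis.

    reads:   (num_reads, read_length, 4)
    weights: (num_reads,)
    returns: (read_length, 4)
    """
    reads = np.asarray(reads, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    # Sum each read's contribution, scaled by its weight.
    weighted_sum = np.einsum("n,nlc->lc", weights, reads)
    return weighted_sum / (weights.sum() + eps)


# Stand-ins for result["preprocessed_reads"] and result["read_weights"].
base_indices = np.array([[0, 1], [0, 1], [2, 1]])  # three reads, two positions
preprocessed = np.eye(4)[base_indices]             # (3, 2, 4) one-hot
weights = np.array([0.5, 0.5, 1.0])                # first two reads share weight

freqs = weighted_base_frequencies(preprocessed, weights)
```

Here the two identical reads together count as much as the single distinct read, so position 0 is split evenly between the first and third base channels.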

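Training

import jax.numpy as jnp
from flax import nnx

def loss_fn(pipeline, reads, quality, target):
    # Mean squared error between the preprocessed reads and a corrected target.
    data = {"reads": reads, "quality": quality}
    result, _, _ = pipeline.apply(data, {}, None)
    return jnp.mean((result["preprocessed_reads"] - target) ** 2)

# nnx.grad differentiates with respect to the first argument (the pipeline).
grads = nnx.grad(loss_fn)(pipeline, reads, quality, corrected_target)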
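
To show what gradient-based tuning of a preprocessing parameter can look like without depending on diffbio, the following JAX sketch learns a soft quality threshold by plain gradient descent. Everything in it (the `soft_quality_mask` stand-in, the loss, the step size, the toy data) is an assumption for illustration; the pipeline's actual parameters and update rule may differ.

```python
import jax
import jax.numpy as jnp


def soft_quality_mask(threshold, quality, temperature=2.0):
    """Differentiable stand-in for `quality > threshold`."""
    return jax.nn.sigmoid((quality - threshold) / temperature)


def loss_fn(threshold, quality, target_mask):
    """Mean squared error between the soft mask and a target keep/drop mask."""
    mask = soft_quality_mask(threshold, quality)
    return jnp.mean((mask - target_mask) ** 2)


quality = jnp.array([5.0, 12.0, 18.0, 25.0, 33.0, 38.0])
target_mask = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])  # keep the last three

threshold = jnp.array(30.0)  # deliberately poor starting point
step_size = 5.0              # hypothetical learning rate for this toy problem

# value_and_grad differentiates with respect to the first argument.
value_and_grad = jax.jit(jax.value_and_grad(loss_fn))

initial_loss, _ = value_and_grad(threshold, quality, target_mask)
for _ in range(200):
    _, grad = value_and_grad(threshold, quality, target_mask)
    threshold = threshold - step_size * grad  # plain gradient descent
final_loss, _ = value_and_grad(threshold, quality, target_mask)
```

In a real training setup the manual update would typically be replaced by an optimizer library, and the gradient would be taken with respect to all pipeline parameters rather than a single scalar.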