Preprocessing Pipeline API
End-to-end differentiable preprocessing pipeline.
PreprocessingPipeline
diffbio.pipelines.preprocessing.PreprocessingPipeline
```python
PreprocessingPipeline(
    config: PreprocessingPipelineConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)
```
Bases: OperatorModule
End-to-end differentiable preprocessing pipeline.
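This pipeline processes sequencing reads through multiple preprocessing steps. As reflected in PreprocessingPipelineConfig, these include quality filtering, adapter removal, duplicate weighting, and error correction; the last three can be toggled individually.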
Input data structure
- reads: Float[Array, "num_reads read_length 4"] - One-hot encoded reads
- quality: Float[Array, "num_reads read_length"] - Base quality scores
Output data structure (adds):
- preprocessed_reads: Float[Array, "num_reads read_length 4"] - Processed reads
- preprocessed_quality: Float[Array, "num_reads read_length"] - Processed quality
- read_weights: Float[Array, "num_reads"] - Read uniqueness weights
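
To make the input layout concrete, the following sketch builds `reads` and `quality` for two toy reads using plain Python lists. The helper `one_hot_encode` and the A, C, G, T channel order are assumptions for illustration; this page does not state which channel order the pipeline expects.

```python
# Hypothetical helper: the A, C, G, T channel order is an assumption,
# not something this page specifies.
BASES = "ACGT"


def one_hot_encode(read: str) -> list[list[float]]:
    """Encode a DNA string as a (read_length, 4) one-hot matrix."""
    return [[1.0 if base == b else 0.0 for b in BASES] for base in read.upper()]


sequences = ["ACGT", "TTGA"]  # toy reads, all the same length
reads = [one_hot_encode(seq) for seq in sequences]  # (num_reads, read_length, 4)
quality = [[30.0] * len(seq) for seq in sequences]  # (num_reads, read_length)

shape = (len(reads), len(reads[0]), len(reads[0][0]))
print(shape)  # (2, 4, 4)
```

The nested lists can be converted with `jnp.asarray` before they are placed in the `data` dictionary.
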
The pipeline is fully differentiable, supporting gradient-based training to optimize all preprocessing components jointly.
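
Differentiability of this kind usually relies on replacing hard cutoffs with smooth ones. The sketch below contrasts a hard quality filter with a sigmoid-based soft filter and evaluates the soft filter's derivative with respect to the threshold. The sigmoid relaxation and the names `hard_keep`, `soft_keep`, and `soft_keep_grad` are assumptions for illustration; this page does not describe how the pipeline implements its filters.

```python
import math


def hard_keep(quality: float, threshold: float) -> float:
    """Hard filter: a step function whose gradient is zero almost everywhere."""
    return 1.0 if quality >= threshold else 0.0


def soft_keep(quality: float, threshold: float, temperature: float = 1.0) -> float:
    """Soft filter: sigmoid((quality - threshold) / temperature)."""
    return 1.0 / (1.0 + math.exp(-(quality - threshold) / temperature))


def soft_keep_grad(quality: float, threshold: float, temperature: float = 1.0) -> float:
    """Analytic derivative of soft_keep with respect to the threshold."""
    s = soft_keep(quality, threshold, temperature)
    return -s * (1.0 - s) / temperature


# A base exactly at the threshold is kept with weight 0.5, and nudging
# the threshold changes that weight, so an optimizer can tune it.
print(soft_keep(20.0, 20.0))       # 0.5
print(soft_keep_grad(20.0, 20.0))  # -0.25
print(hard_keep(20.0, 20.0))       # 1.0
```

In JAX the same derivative would come from `jax.grad` rather than a hand-written formula.
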
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `PreprocessingPipelineConfig` | Pipeline configuration. | required |
| `rngs` | `Rngs` | Random number generators for parameter initialization. | required |
| `name` | `str \| None` | Optional name for the pipeline. | `None` |
apply
```python
apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]
```
Apply the full preprocessing pipeline to reads.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Array]` | Input data containing `reads` (`Float[Array, "num_reads read_length 4"]`) and `quality` (`Float[Array, "num_reads read_length"]`). | required |
| `state` | `dict[str, Any]` | Element state (passed through). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through). | required |
| `random_params` | `Any` | Not used (deterministic pipeline). | `None` |
| `stats` | `dict[str, Any] \| None` | Optional statistics dict. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[dict[str, Array], dict[str, Any], dict[str, Any] \| None]` | Tuple of `(output_data, state, metadata)`, where `output_data` contains all input keys plus the preprocessed outputs. |
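
The return contract can be illustrated with a stand-in function that follows the same calling convention. `apply_stub` is hypothetical and performs no real preprocessing; it only shows that the input keys survive in the output, that new keys are added alongside them, and that `state` and `metadata` are returned unchanged.

```python
from typing import Any, Optional


def apply_stub(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: Optional[dict[str, Any]],
) -> tuple[dict[str, Any], dict[str, Any], Optional[dict[str, Any]]]:
    """Hypothetical stand-in that mimics the documented return contract."""
    output = dict(data)  # shallow copy keeps every input key
    # Placeholder outputs: the inputs themselves and uniform weights.
    output["preprocessed_reads"] = data["reads"]
    output["preprocessed_quality"] = data["quality"]
    output["read_weights"] = [1.0] * len(data["reads"])
    return output, state, metadata


data = {"reads": [[[1.0, 0.0, 0.0, 0.0]]], "quality": [[30.0]]}
output, state, metadata = apply_stub(data, {}, None)
print(sorted(output))
```

Because the output is a new dictionary, the caller's `data` is left without the added keys.
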
PreprocessingPipelineConfig
diffbio.pipelines.preprocessing.PreprocessingPipelineConfig (dataclass)
```python
PreprocessingPipelineConfig(
    read_length: int = 150,
    adapter_sequence: str = "AGATCGGAAGAG",
    quality_threshold: float = 20.0,
    adapter_match_threshold: float = 0.8,
    adapter_temperature: float = 1.0,
    duplicate_similarity_threshold: float = 0.95,
    error_correction_window: int = 11,
    error_correction_hidden_dim: int = 64,
    enable_adapter_removal: bool = True,
    enable_duplicate_weighting: bool = True,
    enable_error_correction: bool = True,
)
```
Bases: OperatorConfig
Configuration for the preprocessing pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| `read_length` | `int` | Expected read length for initialization. |
| `adapter_sequence` | `str` | Adapter sequence to remove (Illumina universal default). |
| `quality_threshold` | `float` | Initial quality score threshold for filtering. |
| `adapter_match_threshold` | `float` | Threshold for adapter matching. |
| `adapter_temperature` | `float` | Temperature for soft adapter trimming. |
| `duplicate_similarity_threshold` | `float` | Similarity threshold for duplicate detection. |
| `error_correction_window` | `int` | Window size for error correction. |
| `error_correction_hidden_dim` | `int` | Hidden dimension for error correction network. |
| `enable_adapter_removal` | `bool` | Whether to enable adapter removal step. |
| `enable_duplicate_weighting` | `bool` | Whether to enable duplicate weighting step. |
| `enable_error_correction` | `bool` | Whether to enable error correction step. |
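
If the quality scores are Phred-scaled, which is the usual convention for sequencing data but an assumption here, the default `quality_threshold` of 20.0 corresponds to a 1% per-base error probability. The helpers below are hypothetical and only illustrate that conversion, which can help when choosing a threshold.

```python
import math


def phred_to_error_prob(q: float) -> float:
    """Convert a Phred quality score to a base-call error probability."""
    return 10.0 ** (-q / 10.0)


def error_prob_to_phred(p: float) -> float:
    """Convert a base-call error probability to a Phred quality score."""
    return -10.0 * math.log10(p)


# Assumption: quality scores are Phred-scaled.
for q in (10.0, 20.0, 30.0):
    print(f"Q{q:.0f}: error probability {phred_to_error_prob(q):.3f}")
# Q10: error probability 0.100
# Q20: error probability 0.010
# Q30: error probability 0.001
```
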
Usage Examples
Basic Usage
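```python
import jax
import jax.numpy as jnp
from flax import nnx

from diffbio.pipelines import PreprocessingPipeline, PreprocessingPipelineConfig

# Configure the pipeline; the three optional steps are enabled by default.
config = PreprocessingPipelineConfig(
    quality_threshold=20.0,
    enable_adapter_removal=True,
    enable_duplicate_weighting=True,
    enable_error_correction=True,
)
pipeline = PreprocessingPipeline(config, rngs=nnx.Rngs(42))

# Placeholder inputs for illustration: 8 random reads of the configured length.
n_reads, read_length = 8, config.read_length
bases = jax.random.randint(jax.random.PRNGKey(0), (n_reads, read_length), 0, 4)
read_sequences = jax.nn.one_hot(bases, 4)  # (n_reads, read_length, 4)
quality_scores = jnp.full((n_reads, read_length), 30.0)  # (n_reads, read_length)

data = {
    "reads": read_sequences,
    "quality": quality_scores,
}

# apply returns (output_data, state, metadata); only the data is needed here.
result, _, _ = pipeline.apply(data, {}, None)
preprocessed = result["preprocessed_reads"]
weights = result["read_weights"]
```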
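
One way to consume `read_weights` downstream is as per-read weights in an aggregate, so that likely duplicates contribute less. The sketch below computes a weighted base composition at one position with plain Python lists. `weighted_base_frequencies` is a hypothetical helper, and the reading that lower weights mark likely duplicates is an assumption; this page does not prescribe how the weights should be used.

```python
def weighted_base_frequencies(
    reads: list[list[list[float]]],
    weights: list[float],
    position: int,
) -> list[float]:
    """Weighted average of the one-hot rows at `position` across reads."""
    total = sum(weights)
    num_channels = len(reads[0][position])
    return [
        sum(w * read[position][c] for read, w in zip(reads, weights)) / total
        for c in range(num_channels)
    ]


# Three single-base reads: two down-weighted near-duplicates and one unique read.
reads = [
    [[1.0, 0.0, 0.0, 0.0]],
    [[1.0, 0.0, 0.0, 0.0]],
    [[0.0, 0.0, 1.0, 0.0]],
]
weights = [0.5, 0.5, 1.0]

print(weighted_base_frequencies(reads, weights, position=0))  # [0.5, 0.0, 0.5, 0.0]
```

Without the weights, the duplicated base would account for two thirds of the composition instead of half.
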