Preprocessing Operators API
Differentiable preprocessing operators for read quality control, adapter removal, and error correction.
SoftAdapterRemoval
diffbio.operators.preprocessing.adapter_removal.SoftAdapterRemoval
SoftAdapterRemoval(
    config: AdapterRemovalConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
Bases: TemperatureOperator
Differentiable adapter removal for sequencing reads.
This operator performs soft adapter trimming using a differentiable approach. It finds potential adapter matches at the 3' end of reads and applies sigmoid-weighted trimming that maintains gradient flow.
The algorithm:
1. Compute soft alignment scores between the sequence suffix and the adapter.
2. Find the soft trim position using weighted position averaging.
3. Apply sigmoid-weighted retention to each position.
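The three steps above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the library's implementation: the names `soft_trim` and `encode`, the mean per-base match score, the extra "no adapter" candidate, and the way `match_threshold` and `min_overlap` enter the computation are all hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp


def soft_trim(sequence, adapter, temperature=1.0, match_threshold=0.5, min_overlap=6):
    """Illustrative soft 3' adapter trimming (hypothetical helper, not diffbio code).

    sequence: (length, alphabet_size) one-hot read
    adapter:  (adapter_len, alphabet_size) one-hot adapter
    """
    length, adapter_len = sequence.shape[0], adapter.shape[0]
    # Candidate adapter start positions that leave at least min_overlap bases.
    starts = jnp.arange(length - min_overlap + 1)

    def score_at(start):
        # Step 1: fraction of overlapping bases where read and adapter agree.
        idx = start + jnp.arange(adapter_len)
        valid = idx < length
        bases = sequence[jnp.clip(idx, 0, length - 1)]
        match = jnp.sum(bases * adapter, axis=-1) * valid
        return jnp.sum(match) / jnp.maximum(jnp.sum(valid), 1)

    scores = jax.vmap(score_at)(starts)

    # Step 2: softmax-weighted average of candidate positions. An extra
    # candidate at position `length` (logit 0) stands for "no adapter found".
    logits = jnp.concatenate([(scores - match_threshold) / temperature, jnp.zeros(1)])
    candidates = jnp.concatenate([starts, jnp.array([length])]).astype(jnp.float32)
    trim_pos = jnp.sum(jax.nn.softmax(logits) * candidates)

    # Step 3: sigmoid retention; the 0.5 offset centres the transition
    # between the last kept base and the first adapter base.
    positions = jnp.arange(length)
    retention = jax.nn.sigmoid((trim_pos - 0.5 - positions) / temperature)
    return sequence * retention[:, None], retention, trim_pos


def encode(dna):
    # Assumed base order "ACGT"; the order diffbio expects may differ.
    return jax.nn.one_hot(jnp.array(["ACGT".index(b) for b in dna]), 4)


read = encode("C" * 30 + "AGATCGGAAGAG")  # 30 read bases followed by the adapter
adapter = encode("AGATCGGAAGAG")
trimmed, retention, trim_pos = soft_trim(read, adapter, temperature=0.05)
```

With a low temperature the retention mask approaches a hard cut at the adapter start, while every step stays differentiable with respect to both the read and the temperature.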
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `AdapterRemovalConfig` | AdapterRemovalConfig with adapter parameters. | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators (optional). | `None` |
| `name` | `str \| None` | Optional operator name. | `None` |
apply
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply soft adapter removal to sequence data.
This method finds potential adapter sequences and applies differentiable soft trimming to remove them while maintaining gradient flow.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `PyTree` | Dictionary containing `"sequence"`: one-hot encoded sequence of shape `(length, alphabet_size)`, and `"quality_scores"`: Phred quality scores of shape `(length,)`. | required |
| `state` | `PyTree` | Element state (passed through unchanged). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| `random_params` | `Any` | Not used (deterministic operator). | `None` |
| `stats` | `dict[str, Any] \| None` | Not used. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
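The `data` dictionary described under Parameters can be assembled from a FASTQ-style record roughly as follows. This is a sketch only: the helper name `encode_read` and the base order `"ACGT"` are assumptions, and the alphabet size and ordering that diffbio expects should be checked against the library. The Phred+33 offset is the standard FASTQ quality encoding.

```python
import jax
import jax.numpy as jnp

ALPHABET = "ACGT"  # assumed base order; verify against the library's convention


def encode_read(bases, fastq_quality):
    """Build the input dictionary for apply() from a read and its quality string."""
    indices = jnp.array([ALPHABET.index(b) for b in bases])
    sequence = jax.nn.one_hot(indices, len(ALPHABET))  # (length, alphabet_size)
    # FASTQ Phred+33 encoding: quality score = ASCII code - 33.
    quality = jnp.array([ord(c) - 33 for c in fastq_quality], dtype=jnp.float32)
    return {"sequence": sequence, "quality_scores": quality}


# 4 read bases followed by the 12-base adapter; the last 4 bases have low quality.
data = encode_read("ACGTAGATCGGAAGAG", "IIIIIIIIIIII####")
```

The resulting dictionary has a `(16, 4)` one-hot array under `"sequence"` and a `(16,)` float array under `"quality_scores"`, matching the shapes listed in the parameter table.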
AdapterRemovalConfig
diffbio.operators.preprocessing.adapter_removal.AdapterRemovalConfig
dataclass
AdapterRemovalConfig(
    adapter_sequence: str = "AGATCGGAAGAG",
    temperature: float = 1.0,
    learnable_temperature: bool = True,
    match_threshold: float = 0.5,
    min_overlap: int = 6,
)
Bases: OperatorConfig
Configuration for SoftAdapterRemoval.
Attributes:
| Name | Type | Description |
|---|---|---|
| `adapter_sequence` | `str` | Adapter sequence to remove (default: Illumina universal). |
| `temperature` | `float` | Temperature for soft matching and trimming. Lower = sharper trimming, Higher = smoother. |
| `match_threshold` | `float` | Minimum alignment score ratio to consider a match. |
| `min_overlap` | `int` | Minimum overlap length to consider adapter presence. |
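To make the effect of `temperature` concrete, the snippet below evaluates a sigmoid retention mask at two temperatures. It is an illustration only: the fixed `trim_pos` and the `retention` helper are hypothetical, and the exact formula the operator uses is not shown in this reference.

```python
import jax
import jax.numpy as jnp

positions = jnp.arange(10.0)
trim_pos = 5.0  # hypothetical soft trim position


def retention(temperature):
    # Positions before trim_pos get weights near 1, later ones are suppressed.
    return jax.nn.sigmoid((trim_pos - positions) / temperature)


sharp = retention(0.1)   # close to a hard 0/1 mask
smooth = retention(5.0)  # gradual ramp around trim_pos

# The mask is differentiable in the temperature itself, which is what
# would allow a temperature parameter to be trained.
temperature_grad = jax.grad(lambda t: jnp.sum(retention(t)))(1.0)
```

At temperature 0.1 the weight drops from nearly 1 to nearly 0 within one position of `trim_pos`; at 5.0 neighbouring positions differ only slightly, so gradients reach a wider span of the read.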
DifferentiableDuplicateWeighting
diffbio.operators.preprocessing.duplicate_filter.DifferentiableDuplicateWeighting
DifferentiableDuplicateWeighting(
    config: DuplicateWeightingConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
Bases: TemperatureOperator
Differentiable duplicate weighting for sequencing reads.
This operator assigns probabilistic weights to reads based on their uniqueness within a batch. Instead of hard duplicate removal, it down-weights reads that are similar to others, maintaining gradient flow.
The algorithm:
1. Embed sequences using learned convolutional features.
2. Compute a pairwise soft similarity matrix.
3. Compute soft cluster sizes from the similarity matrix.
4. Assign weights inversely proportional to cluster size.
Note: This operator works on batched data where reads can be compared. For single-read processing, it returns weight=1.0.
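A rough JAX sketch of the similarity, cluster-size, and weighting steps follows. It is not the library's implementation: the learned convolutional embedding is replaced by flattened, L2-normalised one-hot reads, and the function name `soft_duplicate_weights` and the sigmoid gating around `similarity_threshold` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def soft_duplicate_weights(reads, temperature=1.0, similarity_threshold=0.9):
    """reads: (batch, length, alphabet_size) one-hot. Returns (batch,) weights."""
    # Step 1 (stand-in): flattened, L2-normalised reads instead of learned features.
    flat = reads.reshape(reads.shape[0], -1)
    emb = flat / jnp.linalg.norm(flat, axis=-1, keepdims=True)
    # Step 2: pairwise cosine similarity, squashed into a soft duplicate indicator.
    similarity = emb @ emb.T
    soft_dup = jax.nn.sigmoid((similarity - similarity_threshold) / temperature)
    # Step 3: soft cluster size = soft count of reads similar to each read (itself included).
    cluster_size = jnp.sum(soft_dup, axis=-1)
    # Step 4: weight inversely proportional to cluster size.
    return 1.0 / cluster_size


def encode(dna):
    # Assumed base order "ACGT".
    return jax.nn.one_hot(jnp.array(["ACGT".index(b) for b in dna]), 4)


# Two identical reads and one distinct read.
reads = jnp.stack([encode("ACGTACGT"), encode("ACGTACGT"), encode("TTTTGGGG")])
weights = soft_duplicate_weights(reads, temperature=0.01)
```

With a low temperature the two identical reads each receive a weight close to 0.5 and the distinct read a weight close to 1.0, so the duplicated pair contributes about as much in total as a single unique read.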
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `DuplicateWeightingConfig` | DuplicateWeightingConfig with weighting parameters. | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. | `None` |
| `name` | `str \| None` | Optional operator name. | `None` |
apply
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply duplicate weighting to sequence data.
For single sequences, returns weight=1.0. For batched sequences, computes uniqueness-based weights.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `PyTree` | Dictionary containing `"sequence"`: one-hot encoded sequence of shape `(length, alphabet_size)` or batch of shape `(batch, length, alphabet_size)`, and `"quality_scores"`: quality scores of shape `(length,)` or `(batch, length)`. | required |
| `state` | `PyTree` | Element state (passed through unchanged). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| `random_params` | `Any` | Not used (deterministic operator). | `None` |
| `stats` | `dict[str, Any] \| None` | Not used. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
DuplicateWeightingConfig
diffbio.operators.preprocessing.duplicate_filter.DuplicateWeightingConfig
dataclass
DuplicateWeightingConfig(
    temperature: float = 1.0,
    learnable_temperature: bool = True,
    similarity_threshold: float = 0.9,
    embedding_dim: int = 32,
)
Bases: OperatorConfig
Configuration for DifferentiableDuplicateWeighting.
Attributes:
| Name | Type | Description |
|---|---|---|
| `temperature` | `float` | Temperature for soft similarity computation. Lower = sharper clustering, Higher = smoother. |
| `similarity_threshold` | `float` | Minimum similarity to consider as duplicate. |
| `embedding_dim` | `int` | Dimension of learned sequence embedding. |
SoftErrorCorrection
diffbio.operators.preprocessing.error_correction.SoftErrorCorrection
SoftErrorCorrection(
    config: ErrorCorrectionConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
Bases: TemperatureOperator
Differentiable error correction for sequencing reads.
This operator uses a neural network to refine base calls based on local sequence context and quality scores. It outputs soft base probabilities that maintain gradient flow.
The algorithm:
1. For each position, extract a window of sequence and quality data.
2. Pass the window through an MLP to predict corrected base probabilities.
3. Output a soft one-hot representation blending the original and corrected bases.
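The window, MLP, and blend steps can be sketched in plain JAX as below. Everything here is an assumption made for illustration rather than the library's code: the helper names (`extract_windows`, `init_mlp`, `soft_correct`), the zero padding at read edges, the quality scaling by 40, the untrained random weights, and in particular the blending rule, which weights the original base by the Phred-derived probability that the call is correct, `1 - 10^(-Q/10)`.

```python
import jax
import jax.numpy as jnp


def extract_windows(x, window_size):
    """Stack a centred window of rows for every position; edges are zero-padded."""
    half = window_size // 2  # window_size is expected to be odd
    padded = jnp.pad(x, ((half, half), (0, 0)))
    idx = jnp.arange(x.shape[0])[:, None] + jnp.arange(window_size)[None, :]
    return padded[idx].reshape(x.shape[0], -1)  # (length, window_size * features)


def init_mlp(key, sizes):
    """Random (untrained) dense layers: a list of (weight, bias) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) * jnp.sqrt(2.0 / m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]


def soft_correct(params, sequence, quality, window_size=11, temperature=1.0):
    # Step 1: per-position context of one-hot bases plus scaled quality.
    features = jnp.concatenate([sequence, quality[:, None] / 40.0], axis=-1)
    h = extract_windows(features, window_size)
    # Step 2: MLP followed by a temperature-scaled softmax over bases.
    for w, b in params[:-1]:
        h = jax.nn.relu(h @ w + b)
    w, b = params[-1]
    predicted = jax.nn.softmax((h @ w + b) / temperature, axis=-1)
    # Step 3 (assumed rule): trust the original call in proportion to its
    # Phred-derived probability of being correct.
    p_correct = (1.0 - 10.0 ** (-quality / 10.0))[:, None]
    return p_correct * sequence + (1.0 - p_correct) * predicted


length, alphabet_size = 20, 4
sequence = jax.nn.one_hot(jnp.arange(length) % alphabet_size, alphabet_size)
quality = jnp.full((length,), 30.0)
# Two hidden layers of width 64, mirroring the config defaults.
params = init_mlp(jax.random.PRNGKey(0), [11 * (alphabet_size + 1), 64, 64, alphabet_size])
corrected = soft_correct(params, sequence, quality)
```

Because each output row is a convex combination of two probability vectors, it remains a valid distribution over bases, and at Q30 the original call still dominates.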
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ErrorCorrectionConfig` | ErrorCorrectionConfig with model parameters. | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. | `None` |
| `name` | `str \| None` | Optional operator name. | `None` |
apply
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply error correction to sequence data.
This method corrects each position in the sequence using the neural network model, producing soft corrected base probabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `PyTree` | Dictionary containing `"sequence"`: one-hot encoded sequence of shape `(length, alphabet_size)`, and `"quality_scores"`: Phred quality scores of shape `(length,)`. | required |
| `state` | `PyTree` | Element state (passed through unchanged). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| `random_params` | `Any` | Not used (deterministic operator). | `None` |
| `stats` | `dict[str, Any] \| None` | Not used. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
ErrorCorrectionConfig
diffbio.operators.preprocessing.error_correction.ErrorCorrectionConfig
dataclass
ErrorCorrectionConfig(
    window_size: int = 11,
    hidden_dim: int = 64,
    num_layers: int = 2,
    use_quality: bool = True,
    temperature: float = 1.0,
    learnable_temperature: bool = True,
)
Bases: OperatorConfig
Configuration for SoftErrorCorrection.
Attributes:
| Name | Type | Description |
|---|---|---|
| `window_size` | `int` | Size of context window around each position. Must be odd. Default is 11 (5 bases on each side). |
| `hidden_dim` | `int` | Hidden layer dimension in the MLP. |
| `num_layers` | `int` | Number of hidden layers in the MLP. |
| `use_quality` | `bool` | Whether to include quality scores as input. |
| `temperature` | `float` | Temperature for output softmax. |
Usage Examples
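Adapter Removal
import jax
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.preprocessing import SoftAdapterRemoval, AdapterRemovalConfig
config = AdapterRemovalConfig(
    adapter_sequence="AGATCGGAAGAG",
    temperature=1.0,
    match_threshold=0.8,
)
adapter_removal = SoftAdapterRemoval(config, rngs=nnx.Rngs(42))
read = jax.nn.one_hot(jnp.arange(40) % 4, 4)  # placeholder one-hot read, shape (length, alphabet_size); a 4-letter alphabet is assumed
quality = jnp.full((40,), 30.0)  # placeholder Phred scores, shape (length,)
data = {"sequence": read, "quality_scores": quality}
result, _, _ = adapter_removal.apply(data, {}, None)
trimmed = result["sequence"]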
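Duplicate Weighting
import jax
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.preprocessing import (
    DifferentiableDuplicateWeighting,
    DuplicateWeightingConfig,
)
config = DuplicateWeightingConfig(embedding_dim=32)
dup_weighting = DifferentiableDuplicateWeighting(config, rngs=nnx.Rngs(42))
reads = jax.nn.one_hot(jnp.zeros((8, 40), dtype=jnp.int32), 4)  # placeholder batch of identical reads, shape (batch, length, alphabet_size)
quality = jnp.full((8, 40), 30.0)  # placeholder quality scores, shape (batch, length)
data = {"sequence": reads, "quality_scores": quality}  # key is "sequence", as documented for apply
result, _, _ = dup_weighting.apply(data, {}, None)
weights = result["weights"]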
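Error Correction
import jax
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.preprocessing import SoftErrorCorrection, ErrorCorrectionConfig
config = ErrorCorrectionConfig(hidden_dim=64, window_size=5)
error_correction = SoftErrorCorrection(config, rngs=nnx.Rngs(42))
read = jax.nn.one_hot(jnp.arange(40) % 4, 4)  # placeholder one-hot read, shape (length, alphabet_size); a 4-letter alphabet is assumed
quality = jnp.full((40,), 30.0)  # placeholder Phred scores, shape (length,)
data = {"sequence": read, "quality_scores": quality}
result, _, _ = error_correction.apply(data, {}, None)
corrected = result["sequence"]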