Quality Filter API
Differentiable quality filter for sequence preprocessing.
DifferentiableQualityFilter
diffbio.operators.quality_filter.DifferentiableQualityFilter
```python
DifferentiableQualityFilter(
    config: QualityFilterConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
```
Bases: OperatorModule
Differentiable quality filter for DNA/RNA sequences.
This operator applies soft quality filtering using a sigmoid function to weight sequence positions by their quality scores. High-quality positions (above threshold) pass through with high weight, while low-quality positions are down-weighted.
The threshold is a learnable parameter that can be optimized end-to-end with the rest of the pipeline.
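Formula

$$
\text{retention\_weight} = \sigma(\text{quality\_score} - \text{threshold})
$$

$$
\text{filtered\_sequence} = \text{sequence} \cdot \text{retention\_weight}
$$

where $\sigma$ is the logistic sigmoid. Each position's weight scales that position's entire one-hot row.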
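As a minimal sketch of the formula in plain JAX (independent of the `diffbio` implementation, whose internals may differ), the per-position weight is broadcast across the alphabet axis of the one-hot sequence. The helper name `soft_quality_weights` is hypothetical.

```python
import jax
import jax.numpy as jnp


def soft_quality_weights(sequence, quality_scores, threshold):
    """Hypothetical helper: weight each position by sigmoid(quality - threshold).

    sequence:       (length, alphabet_size) one-hot array
    quality_scores: (length,) Phred scores
    threshold:      scalar Phred threshold
    """
    # One weight per position, strictly between 0 and 1.
    retention_weight = jax.nn.sigmoid(quality_scores - threshold)
    # Trailing axis so each weight scales the whole one-hot row.
    filtered_sequence = sequence * retention_weight[:, None]
    return filtered_sequence, retention_weight


sequence = jnp.eye(4)[jnp.array([0, 1, 2, 3])]  # ACGT
quality = jnp.array([30.0, 20.0, 10.0, 20.0])
filtered, weights = soft_quality_weights(sequence, quality, 20.0)
print(weights)  # positions 1 and 3 sit exactly at the threshold, so weight 0.5
```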
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `QualityFilterConfig` | Quality filter configuration, including the initial threshold | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators (optional, since the operator is deterministic) | `None` |
| `name` | `str \| None` | Optional operator name | `None` |
apply
```python
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
```
Apply soft quality filtering to sequence data.
This method applies a differentiable quality filter that weights each position by sigmoid(quality - threshold). High-quality positions retain most of their value, while low-quality positions are down-weighted.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `PyTree` | Dictionary containing `"sequence"`, the one-hot encoded sequence of shape `(length, alphabet_size)`, and `"quality_scores"`, the Phred quality scores of shape `(length,)` | required |
| `state` | `PyTree` | Element state (passed through unchanged) | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through unchanged) | required |
| `random_params` | `Any` | Not used (deterministic operator) | `None` |
| `stats` | `dict[str, Any] \| None` | Not used | `None` |
Returns:

| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of `(transformed_data, state, metadata)`: `transformed_data` contains the weighted sequence and the original quality scores; `state` and `metadata` are passed through unchanged. |
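The data contract of `apply` can be sketched as a pure function. This is an illustration under assumptions, not the `diffbio` source: the name `quality_filter_apply` and its `threshold` argument are hypothetical, and the weighting follows only the formula stated above. Only `"sequence"` is rewritten; `"quality_scores"`, `state`, and `metadata` pass through.

```python
from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp


def quality_filter_apply(
    data: dict[str, jax.Array],
    state: Any,
    metadata: dict[str, Any] | None,
    threshold: float = 20.0,
) -> tuple[dict[str, jax.Array], Any, dict[str, Any] | None]:
    """Hypothetical stand-in for DifferentiableQualityFilter.apply."""
    quality = data["quality_scores"]
    weights = jax.nn.sigmoid(quality - threshold)  # shape (length,)
    transformed = {
        # Weighted one-hot sequence, shape (length, alphabet_size).
        "sequence": data["sequence"] * weights[:, None],
        # Original quality scores are returned untouched.
        "quality_scores": quality,
    }
    # state and metadata are passed through unchanged.
    return transformed, state, metadata


data = {
    "sequence": jnp.eye(4)[jnp.array([0, 1, 2])],
    "quality_scores": jnp.array([35.0, 20.0, 5.0]),
}
out, out_state, out_meta = quality_filter_apply(data, {}, None)
```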
QualityFilterConfig
diffbio.operators.quality_filter.QualityFilterConfig
dataclass
Bases: OperatorConfig
Configuration for DifferentiableQualityFilter.
Attributes:

| Name | Type | Description |
|---|---|---|
| `initial_threshold` | `float` | Initial Phred quality score threshold. Positions with quality below this value are down-weighted. Default is 20.0, which corresponds to a 1% base-call error rate. |
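The 1% figure follows from the standard Phred definition $Q = -10 \log_{10} P_{\text{err}}$, i.e. $P_{\text{err}} = 10^{-Q/10}$. A small sketch for translating candidate thresholds into error rates (the helper `phred_to_error_prob` is hypothetical, not part of `diffbio`):

```python
import jax.numpy as jnp


def phred_to_error_prob(q):
    """Convert Phred quality score(s) to base-call error probability."""
    return 10.0 ** (-jnp.asarray(q) / 10.0)


# Default threshold 20.0 maps to a 1% error rate; 30.0 maps to 0.1%.
probs = phred_to_error_prob(jnp.array([10.0, 20.0, 30.0]))
print(probs)
```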
Usage Examples
Basic Quality Filtering
```python
import jax.numpy as jnp
from diffbio.operators import DifferentiableQualityFilter, QualityFilterConfig

# Configure the filter with a Phred 20 threshold
config = QualityFilterConfig(initial_threshold=20.0)
filter_op = DifferentiableQualityFilter(config)

# Prepare data: one-hot sequence plus per-position Phred scores
sequence = jnp.eye(4)[jnp.array([0, 1, 2, 3, 0, 1])]  # ACGTAC
quality = jnp.array([30.0, 25.0, 10.0, 35.0, 15.0, 28.0])

# Apply the filter; state and metadata are passed through unchanged
data = {"sequence": sequence, "quality_scores": quality}
result, _, _ = filter_op.apply(data, {}, None)

print(f"Original sum: {sequence.sum():.2f}")
print(f"Filtered sum: {result['sequence'].sum():.2f}")
```
Access Threshold
```python
# Read the current (learnable) threshold
threshold = filter_op.threshold[...]
print(f"Threshold: {threshold}")

# Overwrite the threshold in place
filter_op.threshold[...] = 25.0
```
Gradient Computation
```python
from flax import nnx


def filter_loss(filter_op, sequence, quality):
    data = {"sequence": sequence, "quality_scores": quality}
    result, _, _ = filter_op.apply(data, {}, None)
    return result["sequence"].sum()


# filter_op is a Flax NNX module, not a plain array, so use nnx.grad:
# it differentiates w.r.t. the module's nnx.Param state (the threshold).
grads = nnx.grad(filter_loss)(filter_op, sequence, quality)
print(f"Threshold gradient: {grads.threshold}")
```
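Because the weight is a sigmoid, the gradient of this particular loss with respect to the threshold has a closed form: for a one-hot sequence, $\frac{\partial}{\partial t} \sum_i \sigma(q_i - t) = -\sum_i \sigma(q_i - t)\,\bigl(1 - \sigma(q_i - t)\bigr)$. The sketch below checks that with plain `jax.grad` on a scalar threshold, then runs a few gradient-descent steps to illustrate how the threshold can be learned. It uses only the formula above, not the `diffbio` operator; the toy objective and `target_kept` are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

# Same toy data as the basic example: one-hot ACGTAC plus Phred scores.
sequence = jnp.eye(4)[jnp.array([0, 1, 2, 3, 0, 1])]
quality = jnp.array([30.0, 25.0, 10.0, 35.0, 15.0, 28.0])


def filtered_sum(threshold):
    # Mirrors filter_loss, but written directly against the formula.
    weights = jax.nn.sigmoid(quality - threshold)
    return (sequence * weights[:, None]).sum()


# Closed-form derivative at t0 versus jax.grad.
t0 = 20.0
w = jax.nn.sigmoid(quality - t0)
analytic = -(w * (1.0 - w)).sum()
autodiff = jax.grad(filtered_sum)(t0)
print(f"autodiff={float(autodiff):.6f}  analytic={float(analytic):.6f}")

# Hypothetical objective: keep roughly three positions' worth of signal.
target_kept = 3.0


def toy_loss(threshold):
    return (filtered_sum(threshold) - target_kept) ** 2


loss_grad = jax.jit(jax.grad(toy_loss))
threshold = jnp.float32(t0)
for _ in range(300):
    # Plain gradient descent on the scalar threshold.
    threshold = threshold - 2.0 * loss_grad(threshold)
print(f"learned threshold: {float(threshold):.2f}")
```

Raising the threshold lowers the filtered sum, so the descent moves the threshold upward until the two borderline positions (Q = 25 and Q = 28) are partially down-weighted.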
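Filter Response
The filter applies sigmoid weighting, so the retention weight depends only on the difference between the quality score Q and the threshold t:

| Quality vs. threshold | Retention weight |
|---|---|
| Q ≪ t | ≈ 0 (filtered) |
| Q = t | 0.5 |
| Q ≫ t | ≈ 1 (retained) |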
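To make the table concrete, the sketch below evaluates the weight at a few offsets from a threshold of 20.0 (the configured default). It also checks the sigmoid symmetry $\sigma(d) + \sigma(-d) = 1$: the weight of a position d points below the threshold equals one minus the weight of a position d points above it. Plain JAX, assuming no temperature or scaling beyond the stated formula.

```python
import jax
import jax.numpy as jnp

threshold = 20.0
offsets = jnp.array([-20.0, -10.0, -5.0, 0.0, 5.0, 10.0, 20.0])
quality = threshold + offsets
weights = jax.nn.sigmoid(quality - threshold)

for q, w in zip(quality.tolist(), weights.tolist()):
    print(f"Q={q:5.1f}  weight={w:.4f}")

# Offsets are symmetric around 0, so reversing pairs d with -d.
symmetric_sum = weights + weights[::-1]
```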
Input Specifications
sequence
| Property | Value |
|---|---|
| Shape | (length, alphabet_size) |
| Type | Float[Array, ...] |
| Description | One-hot encoded sequence |
quality_scores
| Property | Value |
|---|---|
| Shape | (length,) |
| Type | Float[Array, ...] |
| Description | Phred quality scores |
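A minimal pre-flight check for the input specifications above can catch shape or dtype mismatches before data reaches the operator. The helper `check_quality_filter_inputs` is hypothetical and not part of `diffbio`; it checks only the shapes and float types listed in the tables.

```python
import jax.numpy as jnp


def check_quality_filter_inputs(data):
    """Hypothetical validator for the documented input shapes."""
    sequence = jnp.asarray(data["sequence"])
    quality = jnp.asarray(data["quality_scores"])
    if sequence.ndim != 2:
        raise ValueError(
            f"sequence must be (length, alphabet_size), got {sequence.shape}"
        )
    if quality.shape != (sequence.shape[0],):
        raise ValueError(
            f"quality_scores must be ({sequence.shape[0]},), got {quality.shape}"
        )
    if not (
        jnp.issubdtype(sequence.dtype, jnp.floating)
        and jnp.issubdtype(quality.dtype, jnp.floating)
    ):
        raise TypeError("sequence and quality_scores must be floating point")
    # Return (length, alphabet_size) for convenience.
    return sequence.shape[0], sequence.shape[1]


good = {
    "sequence": jnp.eye(4)[jnp.array([0, 1, 2])],
    "quality_scores": jnp.array([30.0, 20.0, 10.0]),
}
length, alphabet_size = check_quality_filter_inputs(good)
```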
Output Specifications
sequence
| Property | Value |
|---|---|
| Shape | (length, alphabet_size) |
| Type | Float[Array, ...] |
| Description | Quality-weighted sequence |
quality_scores
| Property | Value |
|---|---|
| Shape | (length,) |
| Type | Float[Array, ...] |
| Description | Original quality scores (preserved) |