Quality Filter API

Differentiable quality filter for sequence preprocessing.

DifferentiableQualityFilter

diffbio.operators.quality_filter.DifferentiableQualityFilter

DifferentiableQualityFilter(
    config: QualityFilterConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable quality filter for DNA/RNA sequences.

This operator applies soft quality filtering using a sigmoid function to weight sequence positions by their quality scores. High-quality positions (above threshold) pass through with high weight, while low-quality positions are down-weighted.

The threshold is a learnable parameter that can be optimized end-to-end with the rest of the pipeline.

Formula

retention_weight = sigmoid(quality_score - threshold)
filtered_sequence = sequence * retention_weight
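
The formula above can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: the helper name `soft_quality_filter` is hypothetical, and broadcasting the per-position weight across the alphabet axis is an assumption based on the documented shapes, (length, alphabet_size) for the sequence and (length,) for the quality scores.

```python
import jax
import jax.numpy as jnp


def soft_quality_filter(sequence, quality_scores, threshold):
    """Weight each position of a one-hot sequence by sigmoid(quality - threshold).

    Hypothetical helper for illustration only; not part of diffbio.
    """
    # One retention weight per position, shape (length,)
    retention_weight = jax.nn.sigmoid(quality_scores - threshold)
    # Assumed broadcast over the alphabet axis: (length, alphabet_size) * (length, 1)
    return sequence * retention_weight[:, None]


sequence = jnp.eye(4)[jnp.array([0, 1, 2, 3])]  # ACGT, one-hot
quality = jnp.array([30.0, 20.0, 10.0, 25.0])
filtered = soft_quality_filter(sequence, quality, threshold=20.0)
print(filtered)
```

With a threshold of 20, the position with quality 20 keeps exactly half of its value, the position with quality 30 is almost untouched, and the position with quality 10 is almost fully suppressed.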

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | QualityFilterConfig | QualityFilterConfig with initial threshold | required |
| rngs | Rngs \| None | Flax NNX random number generators | None |

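Example
from flax import nnx
from diffbio.operators import DifferentiableQualityFilter, QualityFilterConfig

# encoded_seq: one-hot array (length, alphabet_size); quality: Phred scores (length,)
config = QualityFilterConfig(initial_threshold=20.0)
filter_op = DifferentiableQualityFilter(config, rngs=nnx.Rngs(42))
data = {"sequence": encoded_seq, "quality_scores": quality}
filtered_data, state, meta = filter_op.apply(data, {}, None, None)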

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | QualityFilterConfig | Quality filter configuration | required |
| rngs | Rngs \| None | Random number generators (optional for deterministic ops) | None |
| name | str \| None | Optional operator name | None |

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply soft quality filtering to sequence data.

This method applies a differentiable quality filter that weights each position by sigmoid(quality - threshold). High quality positions retain most of their value, while low quality positions are down-weighted.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | PyTree | Dictionary containing "sequence", a one-hot encoded sequence of shape (length, alphabet_size), and "quality_scores", Phred quality scores of shape (length,) | required |
| state | PyTree | Element state (passed through unchanged) | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged) | required |
| random_params | Any | Not used (deterministic operator) | None |
| stats | dict[str, Any] \| None | Not used | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata). transformed_data contains the weighted sequence and the original quality scores; state and metadata are passed through unchanged. |

QualityFilterConfig

diffbio.operators.quality_filter.QualityFilterConfig dataclass

QualityFilterConfig(
    initial_threshold: float = PHRED_QUALITY_THRESHOLD,
)

Bases: OperatorConfig

Configuration for DifferentiableQualityFilter.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| initial_threshold | float | Initial Phred quality score threshold. Positions with quality below this are down-weighted. Default is 20.0 (1% error rate). |
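
The "1% error rate" figure follows from the standard Phred definition, P = 10^(-Q/10). The short sketch below shows the conversion for a few scores; the function name `phred_to_error_probability` is illustrative and not part of diffbio.

```python
def phred_to_error_probability(q: float) -> float:
    """Convert a Phred quality score Q to an error probability P = 10^(-Q/10)."""
    return 10.0 ** (-q / 10.0)


for q in (10.0, 20.0, 30.0):
    p = phred_to_error_probability(q)
    print(f"Q={q:.0f}: error probability {p:.4f}")
```

Each 10-point increase in Q divides the error probability by ten, so the default threshold of 20 corresponds to one expected error in 100 base calls.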

Usage Examples

Basic Quality Filtering

import jax.numpy as jnp
from diffbio.operators import DifferentiableQualityFilter, QualityFilterConfig

# Configure
config = QualityFilterConfig(initial_threshold=20.0)
filter_op = DifferentiableQualityFilter(config)

# Prepare data
sequence = jnp.eye(4)[jnp.array([0, 1, 2, 3, 0, 1])]  # ACGTAC
quality = jnp.array([30.0, 25.0, 10.0, 35.0, 15.0, 28.0])

# Apply filter
data = {"sequence": sequence, "quality_scores": quality}
result, _, _ = filter_op.apply(data, {}, None)

print(f"Original sum: {sequence.sum():.2f}")
print(f"Filtered sum: {result['sequence'].sum():.2f}")

Access Threshold

# Get current threshold
threshold = filter_op.threshold[...]
print(f"Threshold: {threshold}")

# Update threshold
filter_op.threshold[...] = 25.0

Gradient Computation

import jax

def filter_loss(filter_op, sequence, quality):
    data = {"sequence": sequence, "quality_scores": quality}
    result, _, _ = filter_op.apply(data, {}, None)
    return result["sequence"].sum()

# Gradient w.r.t. threshold
grads = jax.grad(filter_loss)(filter_op, sequence, quality)
print(f"Threshold gradient: {grads.threshold}")
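
To see what this gradient should look like, the same loss can be written without the operator class and checked against the closed-form derivative of the sigmoid. This is a sketch under the assumption that the operator computes sequence * sigmoid(quality - threshold) as documented; `filtered_sum` is a hypothetical stand-in, not a diffbio function.

```python
import jax
import jax.numpy as jnp


def filtered_sum(threshold, sequence, quality):
    """Sum of the quality-weighted sequence, as a function of the threshold."""
    weights = jax.nn.sigmoid(quality - threshold)
    return (sequence * weights[:, None]).sum()


sequence = jnp.eye(4)[jnp.array([0, 1, 2, 3, 0, 1])]  # ACGTAC
quality = jnp.array([30.0, 25.0, 10.0, 35.0, 15.0, 28.0])
threshold = jnp.array(20.0)

# Automatic differentiation with respect to the threshold (argument 0)
grad_auto = jax.grad(filtered_sum)(threshold, sequence, quality)

# Closed form: d/dt sigmoid(Q - t) = -sigmoid(Q - t) * (1 - sigmoid(Q - t)).
# Each one-hot row sums to 1, so the loss is just the sum of the weights.
w = jax.nn.sigmoid(quality - threshold)
grad_analytic = -(w * (1.0 - w)).sum()

print(f"autodiff: {grad_auto}, analytic: {grad_analytic}")
```

The gradient is negative because raising the threshold lowers every retention weight. Positions whose quality is close to the threshold contribute most, since the sigmoid is steepest there.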

Filter Response

The filter applies sigmoid weighting:

\[w_i = \sigma(Q_i - t) = \frac{1}{1 + e^{-(Q_i - t)}}\]

| Quality vs Threshold | Retention Weight |
| --- | --- |
| Q << t | ~0 (filtered) |
| Q = t | 0.5 |
| Q >> t | ~1 (retained) |
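
A few concrete values make the response curve easier to picture. The sketch below evaluates the sigmoid weight for a range of quality scores around a threshold of 20; the specific scores are arbitrary illustrative inputs.

```python
import jax
import jax.numpy as jnp

threshold = 20.0
quality = jnp.array([0.0, 15.0, 20.0, 25.0, 40.0])

# w_i = sigmoid(Q_i - t), one weight per quality score
weights = jax.nn.sigmoid(quality - threshold)

for q, w in zip(quality.tolist(), weights.tolist()):
    print(f"Q={q:4.1f}  Q - t={q - threshold:+6.1f}  weight={w:.4f}")
```

Because the formula has no scale factor on Q - t, the transition is fairly sharp in Phred units: five points below the threshold the weight is already about 0.007, and five points above it is about 0.993.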

Input Specifications

sequence

| Property | Value |
| --- | --- |
| Shape | (length, alphabet_size) |
| Type | Float[Array, ...] |
| Description | One-hot encoded sequence |

quality_scores

| Property | Value |
| --- | --- |
| Shape | (length,) |
| Type | Float[Array, ...] |
| Description | Phred quality scores |
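
One way to build inputs with these shapes from a raw read is sketched below. It rests on two assumptions that the source does not state: the alphabet order A, C, G, T (inferred from the ACGTAC usage example) and a Phred+33 ASCII quality encoding as used in Sanger and recent Illumina FASTQ files. `encode_read` is a hypothetical helper, not a diffbio function.

```python
import jax.numpy as jnp

ALPHABET = "ACGT"  # assumed index order, inferred from the ACGTAC example
PHRED_OFFSET = 33  # assumption: Phred+33 (Sanger / Illumina 1.8+) encoding


def encode_read(bases: str, quality_string: str):
    """Return (sequence, quality_scores) arrays for one read."""
    if len(bases) != len(quality_string):
        raise ValueError("bases and quality string must have the same length")
    # str.index raises ValueError for characters outside ALPHABET (e.g. N)
    indices = jnp.array([ALPHABET.index(b) for b in bases])
    sequence = jnp.eye(len(ALPHABET))[indices]  # (length, alphabet_size)
    quality_scores = jnp.array(
        [ord(c) - PHRED_OFFSET for c in quality_string], dtype=jnp.float32
    )  # (length,)
    return sequence, quality_scores


sequence, quality_scores = encode_read("ACGTAC", "?:+D0=")
data = {"sequence": sequence, "quality_scores": quality_scores}
print(sequence.shape, quality_scores)
```

The quality string here decodes to the same scores as the basic usage example (30, 25, 10, 35, 15, 28). Ambiguous bases such as N are not handled in this sketch and would need an explicit policy.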

Output Specifications

sequence

| Property | Value |
| --- | --- |
| Shape | (length, alphabet_size) |
| Type | Float[Array, ...] |
| Description | Quality-weighted sequence |

quality_scores

| Property | Value |
| --- | --- |
| Shape | (length,) |
| Type | Float[Array, ...] |
| Description | Original quality scores (preserved) |