Quick Start
This guide walks you through the basics of using DiffBio for differentiable bioinformatics.
Your First Alignment
Let's compute a differentiable Smith-Waterman alignment:
import jax.numpy as jnp
from diffbio.operators.alignment import SmoothSmithWaterman, SmithWatermanConfig
from diffbio.operators.alignment import create_dna_scoring_matrix
# Aligner settings
config = SmithWatermanConfig(
    temperature=1.0,   # Smoothness parameter
    gap_open=-10.0,    # Gap opening penalty
    gap_extend=-1.0,   # Gap extension penalty
)

# Scoring matrix for DNA sequences (base codes: A=0, C=1, G=2, T=3)
scoring_matrix = create_dna_scoring_matrix(match=2.0, mismatch=-1.0)

# Build the differentiable aligner from the settings and the scoring matrix
aligner = SmoothSmithWaterman(config, scoring_matrix=scoring_matrix)
# One-hot encode sequences
# Sequence: ACGT -> [[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]]
def one_hot_dna(sequence_indices):
    """One-hot encode a DNA sequence given as integer base indices.

    Args:
        sequence_indices: Integer base codes (A=0, C=1, G=2, T=3), either as
            an array or as a plain Python sequence such as a list.

    Returns:
        Array with a trailing axis of size 4: each input index selects the
        matching row of the 4x4 identity matrix.
    """
    # Coerce to an array first: JAX arrays reject raw Python lists as indices,
    # so without this only array inputs would work.
    indices = jnp.asarray(sequence_indices)
    return jnp.eye(4)[indices]
seq1 = one_hot_dna(jnp.array([0, 1, 2, 3, 0, 1]))  # ACGTAC
seq2 = one_hot_dna(jnp.array([0, 1, 0, 3, 0, 1]))  # ACATAC

# Run the aligner through the Datarax apply() interface:
# apply(data, state, metadata) returns a 3-tuple whose first item is the output data
data = dict(seq1=seq1, seq2=seq2)
result, _, _ = aligner.apply(data, {}, None)

print(f"Alignment score: {result['score']:.4f}")
print(f"Alignment matrix shape: {result['alignment_matrix'].shape}")
print(f"Soft alignment shape: {result['soft_alignment'].shape}")
Computing Gradients
The key feature of DiffBio is that all operations are differentiable:
import jax
# Loss function driven by the alignment score
def alignment_loss(scoring_matrix, seq1, seq2):
    """Return the negated alignment score of ``seq1`` against ``seq2``.

    A new aligner is built around ``scoring_matrix`` on every call, so the
    result can be differentiated with respect to the scoring matrix.

    Args:
        scoring_matrix: Scoring matrix handed to ``SmoothSmithWaterman``.
        seq1: First sequence, passed to the aligner under the key ``"seq1"``.
        seq2: Second sequence, passed to the aligner under the key ``"seq2"``.

    Returns:
        The value of ``-result["score"]`` from the aligner output.
    """
    aligner = SmoothSmithWaterman(
        SmithWatermanConfig(temperature=1.0),
        scoring_matrix=scoring_matrix,
    )
    outputs, _, _ = aligner.apply({"seq1": seq1, "seq2": seq2}, {}, None)
    # Negated: minimizing this loss maximizes the alignment score
    return -outputs["score"]
# Differentiate the loss with respect to its first argument, the scoring matrix
grad_fn = jax.grad(alignment_loss, argnums=0)
grads = grad_fn(scoring_matrix, seq1, seq2)

print(f"Gradient shape: {grads.shape}")
print(f"Gradient:\n{grads}")
Using the Datarax Interface
DiffBio operators implement the Datarax OperatorModule interface for batch processing:
from diffbio.operators.alignment import SmoothSmithWaterman, SmithWatermanConfig
from diffbio.operators.alignment import create_dna_scoring_matrix
# Setup
scoring = create_dna_scoring_matrix(match=2.0, mismatch=-1.0)
config = SmithWatermanConfig(temperature=1.0)
aligner = SmoothSmithWaterman(config, scoring_matrix=scoring)

# Inputs in Datarax format: a data dictionary plus state and metadata
data = dict(seq1=seq1, seq2=seq2)
state, metadata = {}, None

# Apply operator; it hands back the (data, state, metadata) triple
result_data, state, metadata = aligner.apply(data, state, metadata)

print(f"Score: {result_data['score']:.4f}")
Quality Filtering
Apply soft quality filtering to sequence data:
from diffbio.operators import DifferentiableQualityFilter, QualityFilterConfig
# Quality filter with an initial threshold of Phred 20
config = QualityFilterConfig(initial_threshold=20.0)
filter_op = DifferentiableQualityFilter(config)

# Inputs: the sequence plus one quality score per position
data = dict(
    sequence=seq1,  # One-hot encoded sequence
    quality_scores=jnp.array([30.0, 25.0, 10.0, 35.0, 20.0, 28.0]),
)

# Run the filter; only the returned data dictionary is used here
filtered_data, _, _ = filter_op.apply(data, {}, None)

# Low-quality positions are down-weighted
print(f"Original sequence sum: {seq1.sum():.2f}")
print(f"Filtered sequence sum: {filtered_data['sequence'].sum():.2f}")
Pileup Generation
Generate differentiable pileups from aligned reads:
from diffbio.operators.variant import DifferentiablePileup, PileupConfig
# Pileup generator settings
config = PileupConfig(use_quality_weights=True, reference_length=50)
pileup_op = DifferentiablePileup(config)

# Simulate some aligned reads
num_reads = 10
read_length = 20

# Random per-base values, normalized to distributions over the last axis
reads = jax.nn.softmax(
    jax.random.uniform(jax.random.PRNGKey(0), (num_reads, read_length, 4)),
    axis=-1,
)

# One integer position per read, drawn from [0, 30)
positions = jax.random.randint(
    jax.random.PRNGKey(1), (num_reads,), minval=0, maxval=30
)

# One quality value per read position, drawn from [10, 40)
quality = jax.random.uniform(
    jax.random.PRNGKey(2), (num_reads, read_length), minval=10.0, maxval=40.0
)

# Generate pileup
data = dict(reads=reads, positions=positions, quality=quality)
result, _, _ = pileup_op.apply(data, {}, None)

print(f"Pileup shape: {result['pileup'].shape}")  # (reference_length, 4)
End-to-End Pipeline
Combine operators into a differentiable pipeline:
import jax
from diffbio.operators.alignment import (
SmoothSmithWaterman, SmithWatermanConfig,
)
from diffbio.operators import (
DifferentiableQualityFilter, QualityFilterConfig,
)
from diffbio.operators.alignment import create_dna_scoring_matrix
def pipeline(params, seq1, seq2, quality1, quality2):
    """Quality-filter two sequences, align them, and return the alignment score.

    Args:
        params: Mapping with the keys ``'threshold'`` (initial threshold of
            the quality filter), ``'temperature'`` (aligner temperature) and
            ``'scoring'`` (aligner scoring matrix).
        seq1: First sequence, filtered together with ``quality1``.
        seq2: Second sequence, filtered together with ``quality2``.
        quality1: Quality scores for ``seq1``.
        quality2: Quality scores for ``seq2``.

    Returns:
        The ``"score"`` entry of the aligner output.
    """
    # Step 1: Quality filtering (one filter instance serves both sequences)
    quality_filter = DifferentiableQualityFilter(
        QualityFilterConfig(initial_threshold=params['threshold'])
    )

    def run_filter(sequence, quality_scores):
        filtered, _, _ = quality_filter.apply(
            {"sequence": sequence, "quality_scores": quality_scores}, {}, None
        )
        return filtered['sequence']

    filtered_seq1 = run_filter(seq1, quality1)
    filtered_seq2 = run_filter(seq2, quality2)

    # Step 2: Alignment of the filtered sequences
    smith_waterman = SmoothSmithWaterman(
        SmithWatermanConfig(temperature=params['temperature']),
        scoring_matrix=params['scoring'],
    )
    outputs, _, _ = smith_waterman.apply(
        {"seq1": filtered_seq1, "seq2": filtered_seq2}, {}, None
    )
    return outputs["score"]
# Initial parameter values, keyed the way pipeline() reads them
params = {
    'threshold': 20.0,
    'temperature': 1.0,
    'scoring': create_dna_scoring_matrix(match=2.0, mismatch=-1.0),
}

# Quality scores: six entries of 30.0 for each sequence
quality1 = 30.0 * jnp.ones(6)
quality2 = 30.0 * jnp.ones(6)

# Differentiate the whole pipeline with respect to params (its first argument)
grad_fn = jax.grad(pipeline, argnums=0)
grads = grad_fn(params, seq1, seq2, quality1, quality2)

print("Gradients computed through entire pipeline!")
print(f"Threshold gradient: {grads['threshold']:.6f}")
print(f"Temperature gradient: {grads['temperature']:.6f}")
Next Steps
- Learn about Core Concepts behind differentiable bioinformatics
- Explore the User Guide for in-depth documentation
- See Examples for more use cases
- Read the API Reference for complete API details