Epigenomics Operators¤
DiffBio provides differentiable operators for epigenomic analysis, including peak calling, chromatin state annotation, and contextual sequence models that combine sequence, TF context, and chromatin contacts.
Epigenomics Fully Differentiable
Overview¤
Epigenomics operators enable end-to-end optimization of:
- DifferentiablePeakCaller: CNN-based peak detection for ChIP-seq/ATAC-seq
- ChromatinStateAnnotator: HMM-based chromatin state classification
- ContextualEpigenomicsOperator: Artifex-transformer-based contextual sequence model with TF conditioning and optional chromatin-guidance loss
ContextualEpigenomicsOperator¤
Configurable contextual epigenomics operator that supports three ablation modes from one code path:
- sequence-only
- sequence plus TF context
- sequence plus TF context plus chromatin-guidance loss
The operator reuses Artifex's TransformerEncoder for sequence modeling and
adds a structured chromatin-consistency term instead of introducing a separate
local attention stack for contact guidance.
Quick Start¤
from flax import nnx
from diffbio.operators.epigenomics import (
ContextualEpigenomicsConfig,
ContextualEpigenomicsOperator,
compute_contextual_epigenomics_loss,
)
from diffbio.sources import build_synthetic_contextual_epigenomics_dataset
data = build_synthetic_contextual_epigenomics_dataset(
n_examples=4,
sequence_length=24,
num_tf_features=3,
target_semantics="binary_peak_mask",
num_output_classes=1,
)
config = ContextualEpigenomicsConfig(
hidden_dim=64,
num_layers=2,
num_heads=4,
intermediate_dim=256,
max_length=24,
num_tf_features=3,
num_outputs=1,
use_tf_context=True,
use_chromatin_guidance=True,
chromatin_guidance_weight=0.1,
)
operator = ContextualEpigenomicsOperator(config, rngs=nnx.Rngs(42))
result, state, metadata = operator.apply(data, {}, None)
losses = compute_contextual_epigenomics_loss(operator, data)
logits = result["logits"] # (batch, length)
token_embeddings = result["token_embeddings"] # (batch, length, hidden_dim)
guidance_loss = result["chromatin_guidance_loss"] # scalar
total_loss = losses["total"]
Configuration¤
| Parameter | Type | Default | Description |
|---|---|---|---|
hidden_dim |
int | 64 | Token embedding width |
num_layers |
int | 2 | Number of Artifex transformer layers |
num_heads |
int | 4 | Number of attention heads |
intermediate_dim |
int | 256 | Feed-forward hidden dimension |
max_length |
int | 512 | Maximum supported sequence length |
num_tf_features |
int | 8 | TF-context feature count |
num_outputs |
int | 1 | Output channels per genomic position |
use_tf_context |
bool | True | Enable TF-conditioning path |
use_chromatin_guidance |
bool | False | Enable chromatin-consistency term |
chromatin_guidance_weight |
float | 0.1 | Weight on chromatin guidance |
Inputs And Outputs¤
Expected data keys:
sequence: one-hot sequence tensor(batch, length, 4)or(length, 4)tf_context: TF features(batch, n_tf_features)or(n_tf_features,)chromatin_contacts: contact map(batch, length, length)or(length, length)targets: supervised targets(batch, length)or(length,)sequence_mask: optional binary mask for padded positions
The shared source contract validates shape, leading-dimension alignment,
one-hot sequence normalization, symmetric contact maps, and declared target
semantics. binary_peak_mask targets must be binary, while
chromatin_state_id targets must be integer class IDs within the task's
declared num_output_classes.
Main outputs:
embeddings: pooled sequence embeddingstoken_embeddings: per-position contextualized embeddingslogits: per-position prediction logitschromatin_guidance_loss: unweighted structured contact-consistency term
Training Notes¤
compute_contextual_epigenomics_loss() returns three values:
supervised: BCE or multiclass cross-entropy depending onnum_outputschromatin_guidance: structured consistency term derived from token-embedding similarity and the supplied contact maptotal:supervised + chromatin_guidance_weight * chromatin_guidance
This keeps the three ablation modes on one implementation path instead of forking separate operator classes.
Benchmark Coverage¤
The contextual epigenomics benchmark family now evaluates the real operator in three reproducible ablations:
sequence_onlytf_contexttf_plus_chromatin
The quick-suite reports track downstream task metrics plus a bounded
chromatin-consistency score derived from the structured guidance loss. This is
the supported benchmark path for comparing whether TF conditioning improves the
task metric and whether chromatin guidance improves structural consistency
without introducing a second operator implementation.
Suite reports also carry one contextual_contract block with the canonical
required source keys, target semantics by task, and output-class counts by
task, so peak-calling and chromatin-state comparisons stay on the same I/O
contract.
For regression tracking, the suite can be mirrored into a Calibrax Store with
save_contextual_epigenomics_suite_run() and checked with
check_contextual_epigenomics_suite_regressions(). Those stored comparisons
use dataset, task, and contextual_variant as the comparison axes.
The current contextual suite uses the deterministic synthetic contextual epigenomics source. Treat its ablation gains as reproducible regression evidence only; they are not stable biological promotion evidence until a real cell-type-resolved epigenomics dataset and promotion gate are added.
DifferentiablePeakCaller¤
CNN-based peak calling with learnable filters for ChIP-seq and ATAC-seq data.
Quick Start¤
from flax import nnx
from diffbio.operators.epigenomics import DifferentiablePeakCaller, PeakCallerConfig
# Configure peak caller
config = PeakCallerConfig(
num_filters=32,
kernel_sizes=[3, 5, 7],
threshold=0.5,
temperature=1.0,
)
# Create operator
rngs = nnx.Rngs(42)
peak_caller = DifferentiablePeakCaller(config, rngs=rngs)
# Apply to signal track
data = {"signal": signal} # (length,) or (batch, length)
result, state, metadata = peak_caller.apply(data, {}, None)
# Get peaks
peak_scores = result["peak_scores"] # Per-position peak probability
peak_calls = result["peak_calls"] # Soft peak calls
boundaries = result["peak_boundaries"] # Peak start/end positions
Configuration¤
| Parameter | Type | Default | Description |
|---|---|---|---|
num_filters |
int | 32 | Number of convolutional filters |
kernel_sizes |
list | [3, 5, 7] | Multi-scale kernel sizes |
threshold |
float | 0.5 | Peak calling threshold |
temperature |
float | 1.0 | Temperature for soft operations |
VAE Denoising Mode¤
The peak caller supports an optional VAE denoising step that preprocesses the coverage signal before peak detection. When enabled, a Poisson decoder (following the SCALE approach) models count data, and the denoised signal is used for downstream peak calling. This is configured via the use_vae_denoising parameter.
Architecture¤
graph LR
A["Input Signal"] --> B["VAE Denoising<br/>(optional)"]
B --> C["[Conv1D + ReLU]<br/>× 3 scales"]
C --> D["Concat"]
D --> E["Dense"]
E --> F["Sigmoid"]
F --> G["Peak Scores"]
style A fill:#d1fae5,stroke:#059669,color:#064e3b
style B fill:#ede9fe,stroke:#7c3aed,color:#4c1d95
style C fill:#e0e7ff,stroke:#4338ca,color:#312e81
style D fill:#e0e7ff,stroke:#4338ca,color:#312e81
style E fill:#e0e7ff,stroke:#4338ca,color:#312e81
style F fill:#e0e7ff,stroke:#4338ca,color:#312e81
style G fill:#d1fae5,stroke:#059669,color:#064e3b
Training for Peak Calling¤
import jax
from flax import nnx
def peak_loss(caller, signals, labels):
"""Binary cross-entropy for peak calling."""
data = {"signal": signals}
result, _, _ = caller.apply(data, {}, None)
peak_probs = result["peak_scores"]
# Binary cross-entropy
loss = -jnp.mean(
labels * jnp.log(peak_probs + 1e-8) +
(1 - labels) * jnp.log(1 - peak_probs + 1e-8)
)
return loss
# Compute gradients
grads = nnx.grad(peak_loss)(peak_caller, train_signals, train_labels)
ChromatinStateAnnotator¤
HMM-based chromatin state annotation from histone modification data.
Quick Start¤
from diffbio.operators.epigenomics import ChromatinStateAnnotator, ChromatinStateConfig
# Configure annotator
config = ChromatinStateConfig(
num_states=15, # e.g., ChromHMM default
num_marks=6, # Number of histone marks
temperature=1.0,
)
# Create operator
rngs = nnx.Rngs(42)
annotator = ChromatinStateAnnotator(config, rngs=rngs)
# Apply to histone mark data
data = {"histone_marks": marks} # (length, num_marks)
result, state, metadata = annotator.apply(data, {}, None)
# Get chromatin states
state_probs = result["state_probabilities"] # (length, num_states)
posteriors = result["state_posteriors"] # Forward-backward posteriors
viterbi_path = result["viterbi_path"] # Most likely state sequence
log_likelihood = result["log_likelihood"] # Sequence log-likelihood
Configuration¤
| Parameter | Type | Default | Description |
|---|---|---|---|
num_states |
int | 15 | Number of chromatin states |
num_marks |
int | 6 | Number of histone marks |
temperature |
float | 1.0 | Temperature for soft Viterbi |
HMM Components¤
The ChromatinStateAnnotator learns:
- Transition Matrix: State-to-state transition probabilities
- Emission Matrix: Mark presence probability per state
- Initial Distribution: Starting state probabilities
Chromatin State Interpretation¤
Common chromatin states learned by the model:
| State Type | Typical Marks | Function |
|---|---|---|
| Active Promoter | H3K4me3, H3K27ac | Gene activation |
| Strong Enhancer | H3K4me1, H3K27ac | Distal regulation |
| Weak Enhancer | H3K4me1 | Poised enhancer |
| Transcribed | H3K36me3 | Actively transcribed |
| Heterochromatin | H3K9me3 | Repressed regions |
| Polycomb | H3K27me3 | Developmental silencing |
Differentiability Techniques¤
Soft Peak Boundaries¤
Instead of discrete peak start/end detection:
# Soft peak start: rising edge detection
start_probs = jax.nn.sigmoid((scores[1:] - scores[:-1]) / temperature)
# Soft peak end: falling edge detection
end_probs = jax.nn.sigmoid((scores[:-1] - scores[1:]) / temperature)
Forward-Backward Algorithm¤
ChromatinStateAnnotator uses log-space forward-backward:
- Forward: Compute P(observations up to t, state at t)
- Backward: Compute P(observations after t | state at t)
- Posteriors: Combine for P(state at t | all observations)
All operations use logsumexp for numerical stability.
Use Cases¤
| Application | Operator | Description |
|---|---|---|
| ChIP-seq peak calling | DifferentiablePeakCaller | Find protein binding sites |
| ATAC-seq analysis | DifferentiablePeakCaller | Identify open chromatin |
| Chromatin annotation | ChromatinStateAnnotator | Classify regulatory elements |
| Enhancer prediction | ChromatinStateAnnotator | Find active enhancers |
| Context-aware regulatory modeling | ContextualEpigenomicsOperator | Combine sequence, TF context, and chromatin guidance |
Next Steps¤
- See Statistical Operators for more HMM applications
- Explore Multi-omics Operators for Hi-C analysis