Data Sources¤
DiffBio provides data source modules that integrate with the Datarax framework for loading bioinformatics and drug discovery datasets.
Overview¤
Data sources in DiffBio extend datarax.core.data_source.DataSourceModule, providing:
- Consistent Interface: All sources implement __len__, __getitem__, and __iter__
- Lazy Loading: Data is loaded on demand to minimize memory usage
- Metadata Support: Each element includes source-specific metadata
- Sampler Integration: Direct compatibility with Datarax samplers
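As a rough illustration of the interface contract listed above, the sketch below implements the three methods on a toy in-memory source. ToySource and its dict-based element layout are hypothetical stand-ins; the class does not subclass Datarax's DataSourceModule, whose actual constructor and element type are not shown here. The loader callback is invoked only on access, which is one simple way to get lazy loading.

```python
from typing import Any, Callable, Iterator, List


class ToySource:
    """Hypothetical stand-in for a data source; not the Datarax base class."""

    def __init__(self, n_items: int, loader: Callable[[int], Any]) -> None:
        self._n_items = n_items
        self._loader = loader  # called only when an element is requested

    def __len__(self) -> int:
        return self._n_items

    def __getitem__(self, idx: int) -> dict:
        if not 0 <= idx < self._n_items:
            raise IndexError(idx)
        # Load lazily and attach per-element metadata
        return {"data": self._loader(idx), "metadata": {"idx": idx}}

    def __iter__(self) -> Iterator[dict]:
        for idx in range(self._n_items):
            yield self[idx]


loaded: List[int] = []  # records which indices were actually loaded


def load_square(idx: int) -> int:
    loaded.append(idx)
    return idx * idx


toy = ToySource(4, load_square)
third = toy[2]  # triggers exactly one load
```

Only index 2 has been loaded after the last line; iterating over toy would load the remaining elements one at a time.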
Source Hierarchy¤
DataSourceModule (Datarax)
├── BAMSource # BAM/CRAM file reading
├── FastaSource # FASTA file reading
├── MolNetSource # MoleculeNet benchmark datasets
├── AnnDataSource # AnnData (.h5ad) single-cell data
└── IndexedViewSource # Lazy view over subset of another source
Genomics Sources¤
DiffBio provides optimized data sources for common genomics file formats, built on widely used libraries.
BAMSource - Aligned Reads¤
BAMSource reads aligned sequencing reads from BAM/CRAM files using pysam, a lightweight wrapper around HTSlib.
Quick Start¤
from pathlib import Path
from diffbio.sources import BAMSource, BAMSourceConfig
# Load aligned reads
config = BAMSourceConfig(
    file_path=Path("sample.bam"),
    min_mapping_quality=20,
)
source = BAMSource(config)
print(f"Reads: {len(source)}")
for element in source:
    seq = element.data["sequence"]  # One-hot encoded (length, 4)
    qual = element.data["quality_scores"]  # Phred scores (length,)
    name = element.data["read_name"]
Configuration Options¤
from pathlib import Path
from diffbio.sources import BAMSourceConfig
config = BAMSourceConfig(
    # Required: path to BAM/CRAM file
    file_path=Path("sample.bam"),
    # Optional: reference FASTA (required for CRAM)
    reference_path=Path("reference.fa"),
    # Filter reads by mapping quality
    min_mapping_quality=20,
    # Include unmapped reads
    include_unmapped=False,
    # Query specific region only
    region="chr1:1000000-2000000",
    # How to handle N nucleotides
    handle_n="uniform",  # or "zero"
)
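Mapping quality in the SAM format is Phred-scaled: MAPQ = -10 * log10(P(mapping position is wrong)). The helpers below are a small sketch for translating a min_mapping_quality threshold into an error probability when choosing a cutoff. The function names are hypothetical and not part of DiffBio.

```python
import math


def mapq_to_error_prob(mapq: float) -> float:
    """Probability that the mapping position is wrong, for a Phred-scaled MAPQ."""
    return 10.0 ** (-mapq / 10.0)


def error_prob_to_mapq(p_wrong: float) -> float:
    """Inverse conversion: Phred-scale an error probability."""
    return -10.0 * math.log10(p_wrong)


p20 = mapq_to_error_prob(20)  # the threshold used in the configuration above
p30 = mapq_to_error_prob(30)
```

A threshold of 20 therefore keeps reads whose reported chance of being misplaced is at most 1 in 100, and 30 tightens that to 1 in 1,000.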
Data Element Format¤
element = source[0]
# Read sequence as one-hot encoding
sequence = element.data["sequence"] # shape: (read_length, 4)
quality = element.data["quality_scores"] # shape: (read_length,)
name = element.data["read_name"] # str
# Metadata
idx = element.metadata["idx"]
reference = element.metadata["reference_name"] # e.g., "chr1"
position = element.metadata["reference_start"] # 0-based
mapq = element.metadata["mapping_quality"]
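The reference_start value is 0-based, whereas samtools-style region strings are 1-based and inclusive. Assuming the region parameter follows the samtools convention (this is not confirmed here), the sketch below converts between the two forms. Both helper names are hypothetical.

```python
from typing import Tuple


def to_region_string(contig: str, start0: int, end0: int) -> str:
    """Convert a 0-based, half-open interval [start0, end0) to a
    1-based, inclusive samtools-style region string."""
    if start0 < 0 or end0 <= start0:
        raise ValueError("expected 0 <= start0 < end0")
    return f"{contig}:{start0 + 1}-{end0}"


def parse_region_string(region: str) -> Tuple[str, int, int]:
    """Inverse: parse 'contig:start-end' into a 0-based, half-open interval."""
    contig, span = region.rsplit(":", 1)
    start1, end1 = (int(part.replace(",", "")) for part in span.split("-"))
    return contig, start1 - 1, end1


region = to_region_string("chr1", 999_999, 2_000_000)
```

Here region equals "chr1:1000000-2000000": only the start shifts by one, because the half-open end and the inclusive 1-based end coincide.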
Performance Tips¤
- Use indexed BAM files: Ensure a .bai index exists for random access
- Filter by region: Use the region parameter to load only needed reads
- Filter at load time: Use min_mapping_quality instead of post-filtering
- Use iterators: Process reads one at a time; don't load them all into memory
FastaSource - Reference Sequences¤
FastaSource reads DNA/RNA sequences from FASTA files using pyfaidx, providing samtools-compatible indexed access.
Quick Start¤
from pathlib import Path
from diffbio.sources import FastaSource, FastaSourceConfig
# Load reference genome
config = FastaSourceConfig(
    file_path=Path("genome.fasta"),
)
source = FastaSource(config)
print(f"Chromosomes: {source.sequence_names}")
# Access by name
chr1 = source.get_by_name("chr1")
sequence = chr1.data["sequence"] # One-hot encoded
Configuration Options¤
from pathlib import Path
from diffbio.sources import FastaSourceConfig
config = FastaSourceConfig(
    # Required: path to FASTA file
    file_path=Path("genome.fasta"),
    # How to handle N nucleotides
    handle_n="uniform",  # [0.25, 0.25, 0.25, 0.25]
    # handle_n="zero",  # [0, 0, 0, 0]
    # Create .fai index if missing
    create_index=True,
)
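To make the two handle_n modes concrete, here is a minimal pure-Python sketch of one-hot encoding with ambiguous bases. It assumes an A, C, G, T column order and treats every symbol outside that alphabet like N; both choices are assumptions for illustration, and the actual DiffBio encoder may differ.

```python
from typing import List

BASES = "ACGT"  # assumed column order; check the library for the real one


def one_hot_encode(sequence: str, handle_n: str = "uniform") -> List[List[float]]:
    """Encode a nucleotide string as rows of length 4."""
    if handle_n not in ("uniform", "zero"):
        raise ValueError(f"unknown handle_n mode: {handle_n!r}")
    rows = []
    for base in sequence.upper():
        if base in BASES:
            row = [0.0] * 4
            row[BASES.index(base)] = 1.0
        elif handle_n == "uniform":
            row = [0.25] * 4  # equal weight on every base
        else:
            row = [0.0] * 4  # ambiguous base carries no signal
        rows.append(row)
    return rows


uniform_rows = one_hot_encode("ACN", handle_n="uniform")
zero_rows = one_hot_encode("ACN", handle_n="zero")
```

With "uniform" each row still sums to 1, which keeps the encoding a valid probability distribution over bases; with "zero" an N position contributes nothing to downstream sums.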
Data Element Format¤
element = source[0]
# Sequence as one-hot encoding
sequence = element.data["sequence"] # shape: (seq_length, 4)
seq_id = element.data["sequence_id"] # str: "chr1"
description = element.data["description"] # str: full header
# Metadata
idx = element.metadata["idx"]
length = element.metadata["length"]
Access Patterns¤
from pathlib import Path
from diffbio.sources import FastaSource, FastaSourceConfig
config = FastaSourceConfig(file_path=Path("genome.fasta"))
source = FastaSource(config)
# 1. Iterate over all sequences
for element in source:
    print(f"{element.data['sequence_id']}: {element.metadata['length']} bp")
# 2. Access by index
first_seq = source[0]
# 3. Access by name
chr1 = source.get_by_name("chr1")
# 4. List all sequence names
names = source.sequence_names # ["chr1", "chr2", ...]
# 5. Batch access
batch = source.get_batch(10)
Performance Tips¤
- Use indexed FASTA: A .fai index enables O(1) random access
- Lazy loading: Sequences are loaded only when accessed
- Access by name: Use get_by_name() for specific chromosomes
- BGZF compression: Works with bgzip-compressed files
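The constant-time access comes from the arithmetic a .fai index enables: each row stores the sequence's byte offset, the bases per line, and the bytes per line, so the position of any base can be computed without scanning the file. The sketch below shows that calculation on a toy in-memory FASTA. FaiEntry, base_offset, and fetch are hypothetical names, and the index values were derived by hand for this toy file; pyfaidx handles all of this internally.

```python
import io
from typing import NamedTuple


class FaiEntry(NamedTuple):
    """One row of a .fai index (columns follow the samtools faidx format)."""
    name: str
    length: int     # number of bases in the sequence
    offset: int     # byte offset of the first base
    linebases: int  # bases per line
    linewidth: int  # bytes per line, including the newline


def base_offset(entry: FaiEntry, pos0: int) -> int:
    """Byte offset of the 0-based base position pos0."""
    full_lines, within_line = divmod(pos0, entry.linebases)
    return entry.offset + full_lines * entry.linewidth + within_line


def fetch(handle, entry: FaiEntry, start0: int, end0: int) -> str:
    """Read bases [start0, end0) with one seek, dropping embedded newlines."""
    if not 0 <= start0 < end0 <= entry.length:
        raise ValueError("interval out of range")
    first = base_offset(entry, start0)
    last = base_offset(entry, end0 - 1)
    handle.seek(first)
    raw = handle.read(last - first + 1)
    return raw.replace(b"\n", b"").decode("ascii")


# Toy FASTA held in memory; index values below were computed by hand for it
fasta = io.BytesIO(b">chr1 demo\nACGTAC\nGTTTGA\nCC\n>chr2\nGGGG\n")
index = {
    "chr1": FaiEntry("chr1", 14, 11, 6, 7),
    "chr2": FaiEntry("chr2", 4, 34, 4, 5),
}
window = fetch(fasta, index["chr1"], 4, 9)
```

The window spans a line break in the file, yet it is retrieved with a single seek and a read of six bytes.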
Integration with DiffBio Operators¤
Genomics sources output one-hot encoded sequences compatible with DiffBio operators:
from pathlib import Path
from diffbio.sources import FastaSource, FastaSourceConfig
from diffbio.operators.alignment import SmoothSmithWaterman, SmithWatermanConfig
# Load sequences
fasta = FastaSource(FastaSourceConfig(file_path=Path("genome.fasta")))
seq1 = fasta.get_by_name("seq1").data["sequence"]
seq2 = fasta.get_by_name("seq2").data["sequence"]
# Align with differentiable Smith-Waterman
aligner = SmoothSmithWaterman(SmithWatermanConfig())
result, _, _ = aligner.apply(
    {"query": seq1, "reference": seq2},
    {},
    None,
)
score = result["score"]
Installation¤
Genomics sources require optional dependencies:
# Install genomics dependencies
uv pip install -e ".[genomics]"
# Or install individually
pip install pysam pyfaidx
AnnData Source (Single-Cell)¤
Overview¤
AnnDataSource loads single-cell RNA-seq data from .h5ad files (AnnData format) and converts it to JAX-compatible data dicts suitable for DiffBio operators. Unlike the file-backed genomics sources, it follows the Datarax eager-loading pattern: all data is loaded into JAX arrays at initialization, after which iteration and batching use pure JAX operations.
Quick Start¤
from diffbio.sources import AnnDataSource, AnnDataSourceConfig
config = AnnDataSourceConfig(file_path="pbmc3k.h5ad")
source = AnnDataSource(config)
print(f"Cells: {len(source)}") # 2700
print(f"Shape: {source.load()['counts'].shape}") # (2700, 32738)
# Iterate over cells
for cell in source:
    print(cell["counts"].shape)  # (32738,)
    break
# Get a batch
batch = source.get_batch(32)
print(batch["counts"].shape) # (32, 32738)
Configuration¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| file_path | str | required | Path to the .h5ad file |
| backed | bool | False | Open in memory-mapped mode |
| shuffle | bool | False | Shuffle during iteration |
| seed | int | 42 | Seed for index shuffling |
| split | str | None | Optional split name |
Output Format¤
| Key | Type | Description |
|---|---|---|
| counts | JAX array (n_cells, n_genes) | Dense expression matrix from .X |
| obs | dict of arrays | Cell metadata columns from .obs |
| var | dict of arrays | Gene metadata columns from .var |
| obsm | dict of JAX arrays | Embeddings from .obsm (PCA, UMAP, etc.) |
AnnData Interop Functions¤
DiffBio provides bidirectional conversion between its standard data dict format and AnnData objects for interoperability with scanpy, scvi-tools, and other AnnData-based tools.
to_anndata: DiffBio dict to AnnData¤
from diffbio.sources.anndata_interop import to_anndata
# Convert DiffBio data dict to AnnData
data_dict = source.load()
adata = to_anndata(data_dict)
# Use with scanpy
import scanpy as sc
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
from_anndata: AnnData to DiffBio dict¤
from diffbio.sources.anndata_interop import from_anndata
# Convert AnnData to DiffBio data dict
import scanpy as sc
adata = sc.read_h5ad("data.h5ad")
data_dict = from_anndata(adata)
# Use with DiffBio operators
result, _, _ = operator.apply({"counts": data_dict["counts"]}, {}, None)
Integration with Single-Cell Operators¤
from flax import nnx
from diffbio.sources import AnnDataSource, AnnDataSourceConfig
from diffbio.operators.singlecell import SoftKMeansClustering, SoftClusteringConfig
# Load data
source = AnnDataSource(AnnDataSourceConfig(file_path="pbmc3k.h5ad"))
data = source.load()
# Cluster
clustering = SoftKMeansClustering(
    SoftClusteringConfig(n_clusters=10, n_features=data["counts"].shape[1]),
    rngs=nnx.Rngs(42),
)
result, _, _ = clustering.apply({"embeddings": data["counts"]}, {}, None)
Installation¤
AnnDataSource requires optional dependencies:
MolNet Benchmark Datasets¤
Overview¤
MolNetSource provides access to the MoleculeNet benchmark suite, a widely used benchmark collection for molecular machine learning.
Quick Start¤
from diffbio.sources import MolNetSource, MolNetSourceConfig
# Load the BBBP (Blood-Brain Barrier Penetration) dataset
config = MolNetSourceConfig(
    dataset_name="bbbp",
    split="train",
    download=True,
)
source = MolNetSource(config)
print(f"Dataset: {len(source)} molecules")
print(f"Task type: {source.task_type}")
print(f"Number of tasks: {source.n_tasks}")
# Access individual molecules
element = source[0]
print(f"SMILES: {element.data['smiles']}")
print(f"Label: {element.data['y']}")
Available Datasets¤
Classification Benchmarks¤
| Dataset | Tasks | Description | Molecules |
|---|---|---|---|
| bbbp | 1 | Blood-brain barrier penetration | ~2,000 |
| tox21 | 12 | Toxicity across 12 assays | ~8,000 |
| hiv | 1 | HIV replication inhibition | ~40,000 |
| bace | 1 | BACE-1 inhibitor activity | ~1,500 |
| clintox | 2 | Clinical trial toxicity | ~1,500 |
| sider | 27 | Drug side effects | ~1,400 |
Regression Benchmarks¤
| Dataset | Tasks | Description | Molecules |
|---|---|---|---|
| esol | 1 | Aqueous solubility (log mol/L) | ~1,100 |
| freesolv | 1 | Hydration free energy (kcal/mol) | ~640 |
| lipophilicity | 1 | Octanol/water partition coefficient | ~4,200 |
Configuration Options¤
from pathlib import Path
from diffbio.sources import MolNetSourceConfig
config = MolNetSourceConfig(
    # Required: dataset name
    dataset_name="tox21",
    # Which split to load
    split="train",  # "train", "valid", or "test"
    # Custom data directory (default: ~/.diffbio/molnet)
    data_dir=Path("/path/to/data"),
    # Auto-download if not found
    download=True,
)
Data Element Format¤
Each element from MolNetSource is a DataElement with:
element = source[0]
# Molecular data
smiles = element.data["smiles"] # str: SMILES representation
labels = element.data["y"] # float or array: task labels
# Metadata
idx = element.metadata["idx"] # int: index in split
dataset = element.metadata["dataset"] # str: dataset name
# State (for stateful processing)
state = element.state # dict: empty by default
Multi-Task Datasets¤
Some datasets (tox21, sider, clintox) have multiple prediction tasks:
from diffbio.sources import MolNetSource, MolNetSourceConfig
# Load Tox21 with 12 toxicity tasks
config = MolNetSourceConfig(dataset_name="tox21", split="train")
source = MolNetSource(config)
print(f"Number of tasks: {source.n_tasks}") # 12
element = source[0]
labels = element.data["y"] # Array of shape (12,)
# NaN values indicate missing labels for that task
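Since NaN marks a missing label, a loss over a multi-task label vector has to skip those entries instead of propagating NaN. The sketch below shows the idea in plain Python with a masked binary cross-entropy. masked_bce is a hypothetical helper and the probabilities are made-up illustration values; array code would typically build the same mask with jnp.isnan.

```python
import math
from typing import Sequence


def masked_bce(probs: Sequence[float], labels: Sequence[float]) -> float:
    """Mean binary cross-entropy over tasks whose label is not NaN.

    Each entry of probs must lie strictly between 0 and 1.
    """
    total, n_valid = 0.0, 0
    for p, y in zip(probs, labels):
        if math.isnan(y):
            continue  # missing label: contributes nothing to the loss
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
        n_valid += 1
    return total / n_valid if n_valid else 0.0


labels = [1.0, float("nan"), 0.0]  # the second task has no label
probs = [0.9, 0.5, 0.2]            # illustrative predicted probabilities
loss = masked_bce(probs, labels)
```

Dividing by the number of valid labels, not the number of tasks, keeps the loss scale comparable between molecules with many and few measured assays.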
IndexedViewSource¤
Overview¤
IndexedViewSource provides a lazy view over a subset of another data source. This is particularly useful for:
- Memory Efficiency: Don't duplicate data when splitting
- Split Views: Access train/valid/test splits as separate sources
- Shuffled Access: Optionally shuffle iteration order
Basic Usage¤
from diffbio.sources import IndexedViewSource, IndexedViewSourceConfig
import jax.numpy as jnp
# Create a view of specific indices
indices = jnp.array([0, 5, 10, 15, 20])
config = IndexedViewSourceConfig(
    shuffle=False,  # Preserve index order
)
view = IndexedViewSource(config, parent_source, indices)
print(f"View size: {len(view)}") # 5
element = view[0] # Loads from parent_source[0]
Shuffled Iteration¤
from diffbio.sources import IndexedViewSource, IndexedViewSourceConfig
from flax import nnx
# Shuffle for training
config = IndexedViewSourceConfig(
    shuffle=True,
    seed=42,  # Reproducible shuffling
)
view = IndexedViewSource(config, parent_source, train_indices, rngs=nnx.Rngs(42))
# Each epoch sees data in different order
for element in view:
    # Process shuffled elements
    pass
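The mechanism behind a lazy view can be sketched in a few lines: the view stores only the parent reference and an index array, and position i in the view resolves to parent[indices[i]]. ToyIndexedView below is a hypothetical illustration using Python's random module, not the DiffBio implementation, which takes a config object and nnx.Rngs as shown above.

```python
import random
from typing import Iterator, Sequence


class ToyIndexedView:
    """Hypothetical lazy view: stores indices, never copies parent elements."""

    def __init__(self, parent: Sequence, indices: Sequence[int],
                 shuffle: bool = False, seed: int = 42) -> None:
        self._parent = parent
        self._indices = list(indices)
        self._shuffle = shuffle
        self._rng = random.Random(seed)  # seeded for reproducible order

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, i: int):
        # Position i in the view maps to parent index indices[i]
        return self._parent[self._indices[i]]

    def __iter__(self) -> Iterator:
        order = list(range(len(self._indices)))
        if self._shuffle:
            self._rng.shuffle(order)  # draws a fresh permutation per pass
        for i in order:
            yield self[i]


parent = [f"item{i}" for i in range(25)]
view = ToyIndexedView(parent, [0, 5, 10, 15, 20])
shuffled = ToyIndexedView(parent, [0, 5, 10, 15, 20], shuffle=True, seed=42)
epoch_1 = list(shuffled)
```

Random access through __getitem__ stays in index order even when shuffling is on; only iteration order is permuted, so the same view can serve both lookups and training loops.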
Integration with Splitters¤
IndexedViewSource is typically created automatically by splitters:
from diffbio.sources import MolNetSource, MolNetSourceConfig
from diffbio.splitters import ScaffoldSplitter, ScaffoldSplitterConfig
# Load full dataset
source = MolNetSource(MolNetSourceConfig(dataset_name="bbbp"))
# Split creates IndexedViewSource instances
splitter = ScaffoldSplitter(ScaffoldSplitterConfig(smiles_key="smiles"))
train_source, valid_source, test_source = splitter.create_split_sources(
    source,
    lazy=True,  # Returns IndexedViewSource (memory efficient)
)
# Use as regular data sources
print(f"Train: {len(train_source)} molecules")
print(f"Valid: {len(valid_source)} molecules")
print(f"Test: {len(test_source)} molecules")
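The key property of a scaffold split is that molecules sharing a scaffold never straddle two splits. One common greedy strategy, sketched below, assigns whole groups to train, then valid, then test, largest groups first. This is an illustration under assumptions: the scaffold keys are hypothetical precomputed strings (real scaffolds would come from a cheminformatics toolkit), group_split is a made-up helper, and ScaffoldSplitter may use a different assignment rule.

```python
from collections import defaultdict
from typing import Hashable, List, Sequence, Tuple


def group_split(
    group_keys: Sequence[Hashable],
    frac_train: float = 0.8,
    frac_valid: float = 0.1,
) -> Tuple[List[int], List[int], List[int]]:
    """Assign whole groups to splits so that no group is divided."""
    groups = defaultdict(list)
    for idx, key in enumerate(group_keys):
        groups[key].append(idx)
    # Largest groups first; ties broken by first member for determinism
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))
    n_total = len(group_keys)
    train: List[int] = []
    valid: List[int] = []
    test: List[int] = []
    for members in ordered:
        if len(train) + len(members) <= frac_train * n_total:
            train.extend(members)
        elif len(valid) + len(members) <= frac_valid * n_total:
            valid.extend(members)
        else:
            test.extend(members)
    return train, valid, test


# Hypothetical scaffold keys for ten molecules
scaffolds = ["A"] * 5 + ["B"] * 2 + ["C", "D", "E"]
train_idx, valid_idx, test_idx = group_split(scaffolds)
```

Because rare scaffolds end up in the validation and test sets, this kind of split tends to be harder than a random one and gives a more conservative estimate of generalization to new chemotypes.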
Integration with Datarax Samplers¤
Batch Training Setup¤
from diffbio.sources import MolNetSource, MolNetSourceConfig
from diffbio.splitters import ScaffoldSplitter, ScaffoldSplitterConfig
from datarax.samplers import ShuffleSampler, ShuffleSamplerConfig
from flax import nnx
# 1. Load dataset
source = MolNetSource(MolNetSourceConfig(dataset_name="bbbp"))
# 2. Split by molecular scaffold
splitter = ScaffoldSplitter(ScaffoldSplitterConfig(smiles_key="smiles"))
train_source, valid_source, test_source = splitter.create_split_sources(
    source, lazy=True
)
# 3. Create samplers
train_sampler = ShuffleSampler(
    ShuffleSamplerConfig(batch_size=32, shuffle=True),
    data_source=train_source,
    rngs=nnx.Rngs(42),
)
valid_sampler = ShuffleSampler(
    ShuffleSamplerConfig(batch_size=32, shuffle=False),
    data_source=valid_source,
)
# 4. Training loop
for epoch in range(100):
    # Training
    for batch in train_sampler:
        smiles_batch = [elem.data["smiles"] for elem in batch]
        labels_batch = [elem.data["y"] for elem in batch]
        # Train on batch
    # Validation
    for batch in valid_sampler:
        # Evaluate on batch
        pass
Custom Collation¤
import jax.numpy as jnp
from diffbio.operators.drug_discovery import batch_smiles_to_graphs
def collate_molecules(batch):
    """Custom collation for molecular graphs."""
    smiles_list = [elem.data["smiles"] for elem in batch]
    labels = jnp.array([elem.data["y"] for elem in batch])
    # Convert SMILES to molecular graphs
    graphs = batch_smiles_to_graphs(smiles_list)
    return {
        "graphs": graphs,
        "labels": labels,
    }
# Use with sampler
for batch_elements in train_sampler:
    batch = collate_molecules(batch_elements)
    # batch["graphs"] contains padded molecular graphs
    # batch["labels"] contains label array
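Padding is what lets variable-size molecules share one fixed-shape batch: shorter items are filled up to the longest one, and a mask records which positions are real. How batch_smiles_to_graphs pads internally is not specified here, so the sketch below is a generic illustration with a hypothetical pad_batch helper and made-up per-molecule atomic numbers.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    sequences: Sequence[Sequence[int]], pad_value: int = 0
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad each sequence to the longest length and return (padded, mask)."""
    max_len = max((len(s) for s in sequences), default=0)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in sequences]
    # mask is 1 for real entries and 0 for padding
    mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in sequences]
    return padded, mask


# Illustrative atomic numbers for three small molecules
atom_numbers = [[6, 6, 8], [7], [6, 8]]
padded, mask = pad_batch(atom_numbers)
```

Downstream reductions such as sums or means over atoms should multiply by the mask so that padding positions do not leak into the result.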
Best Practices¤
1. Use Lazy Loading for Large Datasets¤
# GOOD: Lazy loading with IndexedViewSource
train, valid, test = splitter.create_split_sources(source, lazy=True)
# LESS OPTIMAL: Eager loading copies all data
train, valid, test = splitter.create_split_sources(source, lazy=False)
2. Set Seeds for Reproducibility¤
from diffbio.sources import IndexedViewSourceConfig
from diffbio.splitters import RandomSplitterConfig
# Consistent splits across runs
splitter_config = RandomSplitterConfig(seed=42)
# Consistent shuffling
view_config = IndexedViewSourceConfig(shuffle=True, seed=42)
3. Check Dataset Properties¤
import jax.numpy as jnp
from diffbio.sources import MolNetSource, MolNetSourceConfig
source = MolNetSource(MolNetSourceConfig(dataset_name="tox21"))
# Understand the task
print(f"Task type: {source.task_type}")  # "classification"
print(f"Number of tasks: {source.n_tasks}")  # 12
# Handle missing labels in multi-task datasets
for elem in source:
    labels = elem.data["y"]
    valid_mask = ~jnp.isnan(labels)
    # Only compute loss for valid labels
4. Use Domain-Appropriate Splitting¤
# Drug discovery: Split by molecular structure
from diffbio.splitters import ScaffoldSplitter, ScaffoldSplitterConfig
splitter = ScaffoldSplitter(ScaffoldSplitterConfig(smiles_key="smiles"))
# Bioinformatics: Split by sequence identity
from diffbio.splitters import SequenceIdentitySplitter, SequenceIdentitySplitterConfig
splitter = SequenceIdentitySplitter(SequenceIdentitySplitterConfig(
    sequence_key="sequence",
    identity_threshold=0.3,
))
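To show what an identity threshold does, the sketch below clusters sequences greedily: a sequence joins the first cluster whose representative it matches at or above the threshold, otherwise it starts a new cluster. Whole clusters can then be assigned to splits. This is a simplification under stated assumptions: identity is computed position by position on equal-length, pre-aligned strings, and both helper names are hypothetical. Dedicated tools such as CD-HIT or MMseqs2 compute alignment-based identity, and SequenceIdentitySplitter's own method is not described here.

```python
from typing import List, Sequence


def identity(a: str, b: str) -> float:
    """Fraction of matching positions between two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError("sequences must be pre-aligned to equal length")
    if not a:
        return 0.0
    return sum(x == y for x, y in zip(a, b)) / len(a)


def greedy_cluster(seqs: Sequence[str], threshold: float) -> List[int]:
    """Return a cluster id per sequence using first-fit greedy clustering."""
    representatives: List[int] = []  # index of the first member of each cluster
    assignment: List[int] = []
    for i, seq in enumerate(seqs):
        for cluster_id, rep in enumerate(representatives):
            if identity(seq, seqs[rep]) >= threshold:
                assignment.append(cluster_id)
                break
        else:
            # No representative is similar enough: open a new cluster
            representatives.append(i)
            assignment.append(len(representatives) - 1)
    return assignment


seqs = ["ACGTACGT", "ACGTACGA", "TTTTCCCC", "TTTTCCCG"]
clusters = greedy_cluster(seqs, threshold=0.75)
```

A lower threshold merges more sequences into each cluster, which makes the resulting train and test sets more dissimilar from each other.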
Troubleshooting¤
Dataset Download Issues¤
from pathlib import Path
from diffbio.sources import MolNetSourceConfig
# Specify custom directory if default fails
config = MolNetSourceConfig(
    dataset_name="bbbp",
    data_dir=Path("/writable/directory"),
    download=True,
)
Missing Labels¤
import jax.numpy as jnp
# Multi-task datasets may have missing labels
element = source[0]
labels = element.data["y"]
# Check for NaN values
valid_mask = ~jnp.isnan(labels)
valid_labels = labels[valid_mask]
Memory Issues with Large Datasets¤
# Use lazy loading
train, valid, test = splitter.create_split_sources(source, lazy=True)
# Process in batches rather than loading all at once
for batch in sampler:
    # Process batch
    pass
Source Selection Guide¤
| Use Case | Recommended Source |
|---|---|
| BAM/CRAM aligned reads | BAMSource |
| FASTA reference sequences | FastaSource |
| MolNet benchmarks | MolNetSource |
| Single-cell .h5ad files | AnnDataSource |
| Split views | IndexedViewSource (via splitter) |
| Custom datasets | Extend DataSourceModule |
Related Resources¤
Dataset Splitting¤
- Dataset Splitters: Domain-aware splitting for unbiased evaluation
- ScaffoldSplitter: Drug discovery scaffold-based splitting
- SequenceIdentitySplitter: Bioinformatics sequence-based splitting
Operators¤
- Drug Discovery Operators: Use molecular graphs from MolNetSource
- Alignment Operators: Align sequences from FastaSource
- Variant Operators: Process reads from BAMSource
API Reference¤
- Data Sources API: Complete API documentation for all data sources
- Dataset Splitters API: Complete API documentation for all splitters