Splitters API
Dataset splitting utilities for train/validation/test splitting in bioinformatics and drug discovery applications.
Base Classes
SplitterModule
diffbio.splitters.base.SplitterModule
SplitterModule(
config: SplitterConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: StructuralModule
Base class for dataset splitters.
Inherits from StructuralModule because:
- Non-parametric (no learnable parameters)
- Frozen config (splitting strategy is fixed)
- Uses process() method pattern
- Integrates with Datarax data sources
Splitters decide which elements go into the train, validation, and test sets; a Datarax SamplerModule then controls the iteration order within each of those sets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | SplitterConfig | Splitter configuration | required |
| rngs | Rngs \| None | Random number generators for stochastic splitting | None |
| name | str \| None | Optional name for the module | None |
split
split(data_source: DataSourceModule) -> SplitResult
Split a data source into train/valid/test indices.
Subclasses must implement this method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
Returns:
| Type | Description |
|---|---|
| SplitResult | SplitResult with train/valid/test indices |
process
process(data_source: DataSourceModule) -> SplitResult
Process data source using the split method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
Returns:
| Type | Description |
|---|---|
| SplitResult | SplitResult with train/valid/test indices |
k_fold_split
K-fold cross-validation split.
Subclasses may implement this method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
| k | int | Number of folds | 5 |
Returns:
| Type | Description |
|---|---|
| list[tuple[ndarray, ndarray]] | List of (train_indices, val_indices) tuples for each fold |
create_split_sources
create_split_sources(
data_source: DataSourceModule,
split_result: SplitResult | None = None,
lazy: bool = True,
) -> tuple[
DataSourceModule, DataSourceModule, DataSourceModule
]
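Create separate data sources for each split.
Each returned source selects the elements of the original data source given by the split indices and can be used with Datarax samplers independently. With lazy=True the sources are views into the original data source; with lazy=False the selected elements are copied into memory.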
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Original data source | required |
| split_result | SplitResult \| None | Pre-computed split (computed if None) | None |
| lazy | bool | If True, use lazy loading (IndexedViewSource). If False, eagerly load into MemorySource (faster iteration but uses memory). | True |
Returns:
| Type | Description |
|---|---|
| tuple[DataSourceModule, DataSourceModule, DataSourceModule] | Tuple of (train_source, valid_source, test_source) |
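To illustrate the lazy/eager trade-off described above, here is a minimal sketch in plain Python. The class `IndexedView` and the function `make_split_sources` are hypothetical stand-ins, not the actual `IndexedViewSource` or `MemorySource` APIs: the lazy variant stores only the index list and looks elements up on access, while the eager variant copies the selected elements into new lists up front.

```python
from collections.abc import Sequence
from typing import Any


class IndexedView(Sequence):
    """Hypothetical lazy view: keeps a reference to the parent data plus an index list."""

    def __init__(self, parent: Sequence, indices: Sequence[int]) -> None:
        self._parent = parent
        self._indices = list(indices)

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, i: int) -> Any:
        # Elements are fetched from the parent only when accessed.
        return self._parent[self._indices[i]]


def make_split_sources(data, train_idx, valid_idx, test_idx, lazy=True):
    """Hypothetical helper mirroring the lazy/eager choice of create_split_sources."""
    splits = (train_idx, valid_idx, test_idx)
    if lazy:
        return tuple(IndexedView(data, idx) for idx in splits)
    # Eager: copy the selected elements into memory (faster iteration, more memory).
    return tuple([data[i] for i in idx] for idx in splits)


data = [{"x": i} for i in range(10)]
train, valid, test = make_split_sources(data, range(8), [8], [9])
eager = make_split_sources(data, range(8), [8], [9], lazy=False)
data[9] = {"x": 99}  # mutate the original source
print(test[0], eager[2][0])  # {'x': 99} {'x': 9}
```

The lazy view reflects the change to the underlying data, while the eager copy does not; this is the memory-versus-independence trade-off the `lazy` flag controls.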
SplitterConfig
diffbio.splitters.base.SplitterConfig
dataclass
SplitterConfig(
train_frac: float = 0.8,
valid_frac: float = 0.1,
test_frac: float = 0.1,
seed: int | None = None,
)
Bases: StructuralConfig
Base configuration for splitters.
Frozen because splitters are non-parametric (StructuralModule).
Attributes:
| Name | Type | Description |
|---|---|---|
| train_frac | float | Fraction of data for training (default: 0.8) |
| valid_frac | float | Fraction of data for validation (default: 0.1) |
| test_frac | float | Fraction of data for testing (default: 0.1) |
| seed | int \| None | Random seed for reproducibility (optional) |
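The three fractions are presumably meant to sum to 1. The following sketch shows one plausible way to turn them into integer split sizes; the function `split_sizes` and its rounding rule (floor train and valid, give the remainder to test) are assumptions for illustration, not documented library behavior.

```python
import math


def split_sizes(n: int, train_frac: float = 0.8, valid_frac: float = 0.1,
                test_frac: float = 0.1) -> tuple[int, int, int]:
    """Hypothetical helper: convert split fractions into integer counts."""
    total = train_frac + valid_frac + test_frac
    if not math.isclose(total, 1.0):
        raise ValueError(f"fractions must sum to 1, got {total}")
    n_train = int(n * train_frac)
    n_valid = int(n * valid_frac)
    # Assign the rounding remainder to the test split so that no element is dropped.
    n_test = n - n_train - n_valid
    return n_train, n_valid, n_test


print(split_sizes(100))  # (80, 10, 10)
print(split_sizes(101))  # (80, 10, 11)
```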
SplitResult
diffbio.splitters.base.SplitResult
Bases: NamedTuple
Result of a dataset split operation.
Attributes:
| Name | Type | Description |
|---|---|---|
| train_indices | ndarray | Array of indices for training set |
| valid_indices | ndarray | Array of indices for validation set |
| test_indices | ndarray | Array of indices for test set |
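Because each field of a SplitResult is an index array into the original data source, a useful sanity check is that the three sets are pairwise disjoint and together cover every element. The sketch below mirrors the three documented fields with a local NamedTuple (`SplitResultSketch`, a hypothetical stand-in for the real class) and uses NumPy for the check; `check_split` is likewise illustrative.

```python
from typing import NamedTuple

import numpy as np


class SplitResultSketch(NamedTuple):
    """Hypothetical stand-in with the same three fields as SplitResult."""
    train_indices: np.ndarray
    valid_indices: np.ndarray
    test_indices: np.ndarray


def check_split(result: SplitResultSketch, n: int) -> bool:
    """Return True if the three index sets are disjoint and together cover range(n)."""
    combined = np.concatenate([np.asarray(part) for part in result])
    no_duplicates = np.unique(combined).size == combined.size
    covers_all = np.array_equal(np.sort(combined), np.arange(n))
    return bool(no_duplicates and covers_all)


result = SplitResultSketch(np.array([0, 2, 3]), np.array([1]), np.array([4]))
print(check_split(result, 5))  # True
print(len(result.train_indices))  # 3 (the training-set size)
```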
Random Splitters
RandomSplitter
diffbio.splitters.random.RandomSplitter
RandomSplitter(
config: RandomSplitterConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: SplitterModule
Simple random splitting using JAX RNG.
Uses JAX random permutation for reproducible splits. All data points are randomly assigned to train/valid/test sets according to the configured fractions.
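The random assignment described above amounts to shuffling the index range and cutting it at the fraction boundaries. This sketch uses NumPy's seeded `default_rng` as a stand-in for the JAX RNG the splitter uses (the JAX analogue would be `jax.random.permutation`); `random_split` is a hypothetical helper, and giving the rounding remainder to the test set is an assumption.

```python
import numpy as np


def random_split(n: int, train_frac: float = 0.8, valid_frac: float = 0.1,
                 seed: int = 0) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Shuffle indices 0..n-1 with a seeded RNG and cut them into three parts."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = int(n * train_frac)
    n_valid = int(n * valid_frac)
    train = perm[:n_train]
    valid = perm[n_train:n_train + n_valid]
    test = perm[n_train + n_valid:]  # remainder goes to test
    return train, valid, test


train, valid, test = random_split(100, seed=42)
print(len(train), len(valid), len(test))  # 80 10 10
```

Re-running with the same seed yields the same permutation, which is what makes the split reproducible.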
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | RandomSplitterConfig | Random splitter configuration | required |
| rngs | Rngs \| None | Random number generators | None |
| name | str \| None | Optional module name | None |
split
split(data_source: DataSourceModule) -> SplitResult
Split data source randomly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
Returns:
| Type | Description |
|---|---|
| SplitResult | SplitResult with randomly assigned train/valid/test indices |
k_fold_split
K-fold cross-validation split.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
| k | int | Number of folds | 5 |
Returns:
| Type | Description |
|---|---|
| list[tuple[ndarray, ndarray]] | List of (train_indices, val_indices) tuples for each fold |
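K-fold cross-validation partitions a shuffled index range into k roughly equal folds and uses each fold once as the validation set. The sketch below is an illustrative NumPy version (`k_fold_indices` is a hypothetical name, and NumPy's RNG stands in for JAX); it returns the same (train_indices, val_indices) shape as the documented method, but its fold-size and ordering rules are assumptions.

```python
import numpy as np


def k_fold_indices(n: int, k: int = 5, seed: int = 0) -> list[tuple[np.ndarray, np.ndarray]]:
    """Return k (train_indices, val_indices) pairs; each index is validated exactly once."""
    if not 2 <= k <= n:
        raise ValueError("k must be between 2 and n")
    perm = np.random.default_rng(seed).permutation(n)
    folds = np.array_split(perm, k)  # fold sizes differ by at most one
    result = []
    for i, val_idx in enumerate(folds):
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        result.append((train_idx, val_idx))
    return result


for fold_idx, (tr, va) in enumerate(k_fold_indices(10, k=5, seed=42)):
    print(f"Fold {fold_idx}: {len(tr)} train, {len(va)} val")  # 8 train, 2 val each
```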
RandomSplitterConfig
diffbio.splitters.random.RandomSplitterConfig
dataclass
RandomSplitterConfig(
train_frac: float = 0.8,
valid_frac: float = 0.1,
test_frac: float = 0.1,
seed: int | None = None,
)
Bases: SplitterConfig
Configuration for random splitter.
Inherits all fields from SplitterConfig:
- train_frac: Fraction of data for training (default: 0.8)
- valid_frac: Fraction of data for validation (default: 0.1)
- test_frac: Fraction of data for testing (default: 0.1)
- seed: Random seed for reproducibility (optional)
StratifiedSplitter
diffbio.splitters.random.StratifiedSplitter
StratifiedSplitter(
config: StratifiedSplitterConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: SplitterModule
Stratified splitting that preserves class distribution.
Ensures each split has approximately the same class distribution as the original dataset. Useful for imbalanced classification tasks.
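Stratification can be sketched as running the random split separately inside each class and concatenating the results, so every class is divided with the same fractions. The helper `stratified_split` below is hypothetical; it takes a plain list of labels (as would be read from each element's `label_key`) and uses NumPy's seeded RNG, with per-class counts rounded to the nearest integer as an assumed rounding rule.

```python
from collections import defaultdict

import numpy as np


def stratified_split(labels, train_frac=0.8, valid_frac=0.1, seed=0):
    """Split indices so that each class is divided with the same fractions."""
    rng = np.random.default_rng(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    train, valid, test = [], [], []
    for label in sorted(by_class, key=str):  # deterministic class order
        idx = rng.permutation(np.asarray(by_class[label]))
        n_train = int(round(len(idx) * train_frac))
        n_valid = int(round(len(idx) * valid_frac))
        train.extend(idx[:n_train])
        valid.extend(idx[n_train:n_train + n_valid])
        test.extend(idx[n_train + n_valid:])
    return np.array(train), np.array(valid), np.array(test)


labels = [0] * 90 + [1] * 10  # imbalanced: 90% class 0, 10% class 1
train, valid, test = stratified_split(labels, seed=42)
print(len(train), len(valid), len(test))  # 80 10 10
```

With these labels, the training set receives 8 of the 10 minority-class examples, keeping the 10% class share in each split.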
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | StratifiedSplitterConfig | Stratified splitter configuration | required |
| rngs | Rngs \| None | Random number generators | None |
| name | str \| None | Optional module name | None |
split
split(data_source: DataSourceModule) -> SplitResult
Split preserving class distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
Returns:
| Type | Description |
|---|---|
| SplitResult | SplitResult with stratified train/valid/test indices |
StratifiedSplitterConfig
diffbio.splitters.random.StratifiedSplitterConfig
dataclass
StratifiedSplitterConfig(
train_frac: float = 0.8,
valid_frac: float = 0.1,
test_frac: float = 0.1,
seed: int | None = None,
label_key: str = "y",
)
Bases: SplitterConfig
Configuration for stratified splitter.
Attributes:
| Name | Type | Description |
|---|---|---|
| label_key | str | Key in data element containing labels (default: "y") |
Molecular Splitters
ScaffoldSplitter
diffbio.splitters.molecular.ScaffoldSplitter
ScaffoldSplitter(
config: ScaffoldSplitterConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: SplitterModule
Split molecules by Bemis-Murcko scaffold.
Inherits from SplitterModule (StructuralModule) because:
- Non-parametric: scaffold extraction is deterministic
- Frozen config: splitting strategy doesn't change
- Domain-specific: requires RDKit and molecular knowledge
Ensures that train and test sets contain different molecular scaffolds, which prevents data leakage from structurally similar molecules. Scaffold splitting is a widely used standard for drug discovery benchmarks.
Requires RDKit installation.
References
Bemis, Guy W., and Mark A. Murcko. "The properties of known drugs. 1. Molecular frameworks." Journal of Medicinal Chemistry 39.15 (1996): 2887-2893.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ScaffoldSplitterConfig | Scaffold splitter configuration | required |
| rngs | Rngs \| None | Random number generators (unused for scaffold splitting) | None |
| name | str \| None | Optional module name | None |
Raises:
| Type | Description |
|---|---|
| ImportError | If RDKit is not installed |
split
split(data_source: DataSourceModule) -> SplitResult
Split by scaffold, assigning the largest scaffold groups to train first.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
Returns:
| Type | Description |
|---|---|
| SplitResult | SplitResult with scaffold-based train/valid/test indices |
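The largest-first rule can be sketched as a greedy assignment over scaffold groups. To keep the example free of RDKit, it assumes the Bemis-Murcko scaffold SMILES for each molecule has already been computed (for example with RDKit's `MurckoScaffold` module); the function `scaffold_split`, its capacity checks, and its tie-breaking rule are illustrative assumptions, not the diffbio implementation.

```python
from collections import defaultdict


def scaffold_split(scaffolds: list[str], train_frac: float = 0.8,
                   valid_frac: float = 0.1) -> tuple[list[int], list[int], list[int]]:
    """Greedy scaffold split: assign whole scaffold groups, largest first."""
    groups: dict[str, list[int]] = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)
    # Largest groups first; ties broken by first occurrence for determinism.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))
    n = len(scaffolds)
    train_cap = train_frac * n
    valid_cap = (train_frac + valid_frac) * n
    train: list[int] = []
    valid: list[int] = []
    test: list[int] = []
    for group in ordered:
        if len(train) + len(group) <= train_cap:
            train.extend(group)
        elif len(train) + len(valid) + len(group) <= valid_cap:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test


# Pre-computed scaffold SMILES; "" stands for an acyclic molecule with no ring scaffold.
scaffolds = ["c1ccccc1"] * 5 + ["C1CCCCC1"] * 3 + ["c1ccncc1", ""]
train, valid, test = scaffold_split(scaffolds)
print(len(train), len(valid), len(test))  # 8 1 1
```

Because groups are never split, the realized fractions only approximate the configured ones; rare scaffolds tend to end up in the validation and test sets.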
ScaffoldSplitterConfig
diffbio.splitters.molecular.ScaffoldSplitterConfig
dataclass
ScaffoldSplitterConfig(
train_frac: float = 0.8,
valid_frac: float = 0.1,
test_frac: float = 0.1,
seed: int | None = None,
smiles_key: str = "smiles",
)
Bases: SplitterConfig
Configuration for scaffold splitter.
Attributes:
| Name | Type | Description |
|---|---|---|
| smiles_key | str | Key in data element containing SMILES string (default: "smiles") |
TanimotoClusterSplitter
diffbio.splitters.molecular.TanimotoClusterSplitter
TanimotoClusterSplitter(
config: TanimotoClusterSplitterConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: SplitterModule
Split by Tanimoto similarity clustering (Butina algorithm).
Groups similar molecules together using fingerprint similarity, then assigns clusters to train/valid/test to ensure structural diversity between splits.
Inherits from SplitterModule (StructuralModule) because:
- Non-parametric: clustering is deterministic given fingerprints
- Frozen config: splitting strategy doesn't change
- Domain-specific: requires RDKit fingerprints
Requires RDKit installation.
References
Butina, Darko. "Unsupervised data base clustering based on Daylight's fingerprint and Tanimoto similarity." Journal of Chemical Information and Computer Sciences 39.4 (1999): 747-750.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | TanimotoClusterSplitterConfig | Tanimoto cluster splitter configuration | required |
| rngs | Rngs \| None | Random number generators (unused) | None |
| name | str \| None | Optional module name | None |
Raises:
| Type | Description |
|---|---|
| ImportError | If RDKit is not installed |
split
split(data_source: DataSourceModule) -> SplitResult
Cluster by Tanimoto similarity and split.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
Returns:
| Type | Description |
|---|---|
| SplitResult | SplitResult with cluster-based train/valid/test indices |
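A compact pure-Python sketch of the Tanimoto coefficient and the Butina procedure follows. It assumes fingerprints are given as sets of on-bit indices (as RDKit's `GetOnBits()` would provide) rather than RDKit bit-vector objects, and treats `similarity_cutoff` as the minimum similarity for two molecules to count as neighbours; both are assumptions about how the config maps onto the algorithm, and the function names are hypothetical. The resulting clusters would then be assigned to splits as whole groups.

```python
def tanimoto(a: set[int], b: set[int]) -> float:
    """Tanimoto similarity between two fingerprints given as sets of on-bit indices."""
    if not a and not b:
        return 0.0  # convention chosen here for two empty fingerprints
    return len(a & b) / len(a | b)


def butina_clusters(fps: list[set[int]], similarity_cutoff: float = 0.6) -> list[list[int]]:
    """Visit molecules by neighbour count; each unassigned one becomes a centroid
    and absorbs its still-unassigned neighbours."""
    n = len(fps)
    neighbours = [
        {j for j in range(n) if j != i and tanimoto(fps[i], fps[j]) >= similarity_cutoff}
        for i in range(n)
    ]
    # Candidate centroids ordered by neighbour count (largest first).
    order = sorted(range(n), key=lambda i: -len(neighbours[i]))
    assigned: set[int] = set()
    clusters = []
    for i in order:
        if i in assigned:
            continue
        members = [i] + sorted(neighbours[i] - assigned)
        assigned.update(members)
        clusters.append(members)
    return clusters


fps = [{1, 2, 3, 4}, {1, 2, 3, 5}, {1, 2, 3, 4, 5}, {10, 11, 12}, {10, 11, 12, 14}, {20}]
print(butina_clusters(fps, similarity_cutoff=0.6))  # [[0, 1, 2], [3, 4], [5]]
```

Raising the cutoff makes the neighbour criterion stricter and produces more, smaller clusters.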
TanimotoClusterSplitterConfig
diffbio.splitters.molecular.TanimotoClusterSplitterConfig
dataclass
TanimotoClusterSplitterConfig(
train_frac: float = 0.8,
valid_frac: float = 0.1,
test_frac: float = 0.1,
seed: int | None = None,
smiles_key: str = "smiles",
fingerprint_type: str = "morgan",
fingerprint_radius: int = 2,
fingerprint_bits: int = 2048,
similarity_cutoff: float = 0.6,
)
Bases: SplitterConfig
Configuration for Tanimoto cluster splitter.
Attributes:
| Name | Type | Description |
|---|---|---|
| smiles_key | str | Key in data element containing SMILES string (default: "smiles") |
| fingerprint_type | str | Type of fingerprint ("morgan", "rdkit", "maccs") |
| fingerprint_radius | int | Radius for Morgan fingerprints (default: 2) |
| fingerprint_bits | int | Number of bits for fingerprints (default: 2048) |
| similarity_cutoff | float | Tanimoto similarity cutoff for clustering (default: 0.6) |
Sequence Splitters
SequenceIdentitySplitter
diffbio.splitters.sequence.SequenceIdentitySplitter
SequenceIdentitySplitter(
config: SequenceIdentitySplitterConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: SplitterModule
Split sequences by identity threshold.
Groups similar sequences using identity clustering, then assigns whole clusters to train/valid/test so that each cluster falls in a single split. This guards against data leakage from similar sequences appearing in different splits.
Inherits from SplitterModule (StructuralModule) because:
- Non-parametric: clustering is deterministic
- Frozen config: splitting strategy doesn't change
- Domain-specific: requires sequence comparison
The approach is similar to CD-HIT or MMseqs2 clustering.
References
Li, Weizhong, and Adam Godzik. "Cd-hit: a fast program for clustering and comparing large sets of protein or nucleotide sequences." Bioinformatics 22.13 (2006): 1658-1659.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | SequenceIdentitySplitterConfig | Sequence identity splitter configuration | required |
| rngs | Rngs \| None | Random number generators (unused for identity splitting) | None |
| name | str \| None | Optional module name | None |
split
split(data_source: DataSourceModule) -> SplitResult
Split by sequence identity clustering.
Clusters sequences by identity, then assigns clusters to train/valid/test splits. Largest clusters go to train first.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSourceModule | Datarax DataSourceModule to split | required |
Returns:
| Type | Description |
|---|---|
| SplitResult | SplitResult with identity-based train/valid/test indices |
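The greedy, CD-HIT-style clustering step can be sketched in a few lines. How the "simple" alignment method actually scores identity is not documented here, so the sketch assumes an ungapped, position-by-position identity normalized by the shorter sequence; `simple_identity` and `greedy_identity_clusters` are hypothetical names. A real pipeline would use proper alignments (or MMseqs2) and would then assign whole clusters to splits, largest first.

```python
def simple_identity(a: str, b: str) -> float:
    """Ungapped identity: matching positions divided by the shorter length (assumption)."""
    if not a or not b:
        return 0.0
    matches = sum(x == y for x, y in zip(a, b))
    return matches / min(len(a), len(b))


def greedy_identity_clusters(seqs: list[str], threshold: float = 0.3) -> list[list[int]]:
    """CD-HIT-style greedy clustering: longest sequences become representatives first."""
    order = sorted(range(len(seqs)), key=lambda i: -len(seqs[i]))
    representatives: list[int] = []
    clusters: list[list[int]] = []
    for i in order:
        for rep, cluster in zip(representatives, clusters):
            # Join the first cluster whose representative exceeds the threshold.
            if simple_identity(seqs[i], seqs[rep]) > threshold:
                cluster.append(i)
                break
        else:
            # No representative is similar enough: start a new cluster.
            representatives.append(i)
            clusters.append([i])
    return clusters


seqs = ["ACGTACGTAC", "ACGTACGTTT", "GGGGCCCC", "TTTTTTTT"]
print(greedy_identity_clusters(seqs, threshold=0.5))  # [[0, 1], [2], [3]]
```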
SequenceIdentitySplitterConfig
diffbio.splitters.sequence.SequenceIdentitySplitterConfig
dataclass
SequenceIdentitySplitterConfig(
train_frac: float = 0.8,
valid_frac: float = 0.1,
test_frac: float = 0.1,
seed: int | None = None,
sequence_key: str = "sequence",
identity_threshold: float = 0.3,
alignment_method: str = "simple",
)
Bases: SplitterConfig
Configuration for sequence identity splitter.
Attributes:
| Name | Type | Description |
|---|---|---|
| sequence_key | str | Key in data element containing sequence string (default: "sequence") |
| identity_threshold | float | Identity threshold for clustering (default: 0.3). Sequences with identity > threshold are clustered together. |
| alignment_method | str | Method for identity computation ("simple" or "mmseqs2") |
Usage Examples
Basic Random Splitting
```python
from diffbio.splitters import RandomSplitter, RandomSplitterConfig
from flax import nnx

# Create splitter with 80/10/10 split
config = RandomSplitterConfig(
    train_frac=0.8,
    valid_frac=0.1,
    test_frac=0.1,
    seed=42,
)
splitter = RandomSplitter(config, rngs=nnx.Rngs(42))

# Split a data source (`data_source` is an existing Datarax DataSourceModule)
result = splitter.split(data_source)

# SplitResult holds index arrays; their lengths give the split sizes
print(f"Train: {len(result.train_indices)}")
print(f"Valid: {len(result.valid_indices)}")
print(f"Test: {len(result.test_indices)}")
```
Stratified Splitting
```python
from diffbio.splitters import StratifiedSplitter, StratifiedSplitterConfig
from flax import nnx

# Preserve class distribution in splits
config = StratifiedSplitterConfig(
    train_frac=0.8,
    valid_frac=0.1,
    test_frac=0.1,
    label_key="label",  # Key in data element containing class label
    seed=42,
)
splitter = StratifiedSplitter(config, rngs=nnx.Rngs(42))
result = splitter.split(data_source)
```
Scaffold Splitting (Drug Discovery)
```python
from diffbio.splitters import ScaffoldSplitter, ScaffoldSplitterConfig

# Split by Bemis-Murcko scaffold (requires RDKit)
config = ScaffoldSplitterConfig(
    smiles_key="smiles",
    train_frac=0.8,
    valid_frac=0.1,
    test_frac=0.1,
)
splitter = ScaffoldSplitter(config)
result = splitter.split(molecule_source)

# Molecules that share a scaffold end up in the same split
```
Tanimoto Cluster Splitting
```python
from diffbio.splitters import TanimotoClusterSplitter, TanimotoClusterSplitterConfig

# Split by fingerprint similarity clustering (requires RDKit)
config = TanimotoClusterSplitterConfig(
    smiles_key="smiles",
    fingerprint_type="morgan",  # or "rdkit", "maccs"
    fingerprint_radius=2,
    similarity_cutoff=0.6,
)
splitter = TanimotoClusterSplitter(config)
result = splitter.split(molecule_source)
```
Sequence Identity Splitting (Bioinformatics)
```python
from diffbio.splitters import SequenceIdentitySplitter, SequenceIdentitySplitterConfig

# Split by sequence identity clustering
config = SequenceIdentitySplitterConfig(
    sequence_key="sequence",
    identity_threshold=0.3,  # Sequences above 30% identity are clustered together
    alignment_method="simple",  # or "mmseqs2"
)
splitter = SequenceIdentitySplitter(config)
result = splitter.split(sequence_source)

# Sequences in the same identity cluster end up in the same split
```
Creating Split Data Sources
```python
from diffbio.splitters import ScaffoldSplitter, ScaffoldSplitterConfig

# Create splitter
config = ScaffoldSplitterConfig(smiles_key="smiles")
splitter = ScaffoldSplitter(config)

# Create separate data sources for each split
train_source, valid_source, test_source = splitter.create_split_sources(
    data_source,
    lazy=True,  # Use IndexedViewSource for memory efficiency
)

# Use with Datarax samplers
from datarax.samplers import ShuffleSampler, ShuffleSamplerConfig

sampler_config = ShuffleSamplerConfig(batch_size=32)
train_sampler = ShuffleSampler(sampler_config, data_source=train_source)
```
K-Fold Cross-Validation
```python
from diffbio.splitters import RandomSplitter, RandomSplitterConfig
from flax import nnx

config = RandomSplitterConfig(seed=42)
splitter = RandomSplitter(config, rngs=nnx.Rngs(42))

# Get 5-fold splits
folds = splitter.k_fold_split(data_source, k=5)
for fold_idx, (train_indices, val_indices) in enumerate(folds):
    print(f"Fold {fold_idx}: {len(train_indices)} train, {len(val_indices)} val")
```
Input Specifications
All Splitters
| Parameter | Type | Description |
|---|---|---|
| data_source | DataSourceModule | Datarax data source to split |
ScaffoldSplitter / TanimotoClusterSplitter
Data elements must contain:
| Key | Type | Description |
|---|---|---|
| smiles_key | str | SMILES string for the molecule |
SequenceIdentitySplitter
Data elements must contain:
| Key | Type | Description |
|---|---|---|
| sequence_key | str | DNA/RNA/protein sequence |
StratifiedSplitter
Data elements must contain:
| Key | Type | Description |
|---|---|---|
| label_key | int/str | Class label for stratification |
Output Specifications
SplitResult
| Field | Type | Description |
|---|---|---|
| train_indices | jnp.ndarray | Indices for training set |
| valid_indices | jnp.ndarray | Indices for validation set |
| test_indices | jnp.ndarray | Indices for test set |
Splitter Comparison
| Splitter | Use Case | Domain | Dependencies |
|---|---|---|---|
| RandomSplitter | General purpose | Any | None |
| StratifiedSplitter | Class-imbalanced data | Any | None |
| ScaffoldSplitter | Drug discovery | Chemistry | RDKit |
| TanimotoClusterSplitter | Molecular similarity | Chemistry | RDKit |
| SequenceIdentitySplitter | Genomics/Proteomics | Biology | None |