Splitters API

Dataset splitting utilities for train/validation/test splitting in bioinformatics and drug discovery applications.

Base Classes

SplitterModule

diffbio.splitters.base.SplitterModule

SplitterModule(
    config: SplitterConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: StructuralModule

Base class for dataset splitters.

Inherits from StructuralModule because:

  • Non-parametric (no learnable parameters)
  • Frozen config (splitting strategy is fixed)
  • Uses process() method pattern
  • Integrates with Datarax data sources

Splitters decide which samples belong to the train, validation, and test sets; a Datarax SamplerModule then controls the order in which samples are iterated within each set.

Parameters:

Name Type Description Default
config SplitterConfig

Splitter configuration

required
rngs Rngs | None

Random number generators for stochastic splitting

None
name str | None

Optional name for the module

None

split

split(data_source: DataSourceModule) -> SplitResult

Split a data source into train/valid/test indices.

Subclasses must implement this method.

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required

Returns:

Type Description
SplitResult

SplitResult with train/valid/test indices

process

process(data_source: DataSourceModule) -> SplitResult

Process a data source by delegating to split().

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required

Returns:

Type Description
SplitResult

SplitResult with train/valid/test indices
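The split()/process() contract above can be illustrated with a plain-Python sketch. ToySplitResult, ToySplitter, and HeadTailSplitter are hypothetical names: the real classes are StructuralModule subclasses that accept nnx.Rngs and operate on Datarax DataSourceModules, which this sketch replaces with an ordinary Python sequence. The only behaviour taken from the docs is that subclasses implement split() and that process() runs it.

```python
from abc import ABC, abstractmethod
from typing import List, NamedTuple, Sequence


class ToySplitResult(NamedTuple):
    """Hypothetical stand-in for SplitResult, holding plain lists of indices."""
    train_indices: List[int]
    valid_indices: List[int]
    test_indices: List[int]


class ToySplitter(ABC):
    """Hypothetical base class mirroring the split()/process() contract."""

    @abstractmethod
    def split(self, data: Sequence) -> ToySplitResult:
        """Subclasses decide which indices go to which set."""

    def process(self, data: Sequence) -> ToySplitResult:
        # process() adds no logic of its own; it delegates to split().
        return self.split(data)


class HeadTailSplitter(ToySplitter):
    """Deterministic toy subclass: first 80% train, next 10% valid, rest test."""

    def split(self, data: Sequence) -> ToySplitResult:
        n = len(data)
        n_train, n_valid = int(n * 0.8), int(n * 0.1)
        idx = list(range(n))
        return ToySplitResult(
            idx[:n_train],
            idx[n_train:n_train + n_valid],
            idx[n_train + n_valid:],
        )


result = HeadTailSplitter().process(list("abcdefghij"))
print(result.train_indices, result.valid_indices, result.test_indices)
```

Because split() only returns indices and never iterates over the data, ordering can be left entirely to a sampler afterwards.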

k_fold_split

k_fold_split(
    data_source: DataSourceModule, k: int = 5
) -> list[tuple[ndarray, ndarray]]

K-fold cross-validation split.

Subclasses may implement this method.

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required
k int

Number of folds

5

Returns:

Type Description
list[tuple[ndarray, ndarray]]

List of (train_indices, val_indices) tuples for each fold

create_split_sources

create_split_sources(
    data_source: DataSourceModule,
    split_result: SplitResult | None = None,
    lazy: bool = True,
) -> tuple[
    DataSourceModule, DataSourceModule, DataSourceModule
]

Create separate data sources for each split.

This creates views into the original data source using the split indices. Each returned source can be used with Datarax samplers independently.

Parameters:

Name Type Description Default
data_source DataSourceModule

Original data source

required
split_result SplitResult | None

Pre-computed split. If None, the split is computed from data_source.

None
lazy bool

If True, use lazy loading (IndexedViewSource). If False, eagerly load each split into a MemorySource (faster iteration, higher memory use).

True

Returns:

Type Description
tuple[DataSourceModule, DataSourceModule, DataSourceModule]

Tuple of (train_source, valid_source, test_source)
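The lazy/eager trade-off can be sketched in plain Python. IndexedView and make_split_sources are hypothetical stand-ins, not the library's IndexedViewSource or MemorySource: a lazy view stores only the split indices and reads the parent on access, while the eager path copies each split into its own list.

```python
from typing import Any, Sequence


class IndexedView:
    """Hypothetical lazy view: stores indices, reads the parent on access."""

    def __init__(self, parent: Sequence[Any], indices: Sequence[int]):
        self._parent = parent
        self._indices = list(indices)

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, i: int) -> Any:
        # Translate the view-local position into a parent index.
        return self._parent[self._indices[i]]


def make_split_sources(parent, train_idx, valid_idx, test_idx, lazy=True):
    """Return one source per split, as views (lazy) or copies (eager)."""
    splits = (train_idx, valid_idx, test_idx)
    if lazy:
        return tuple(IndexedView(parent, idx) for idx in splits)
    # Eager: materialize each split into its own list (more memory, no indirection).
    return tuple([parent[i] for i in idx] for idx in splits)


data = [{"x": i} for i in range(6)]
train, valid, test = make_split_sources(data, [0, 2, 4], [1], [3, 5])
print(len(train), train[1])  # 3 {'x': 2}
```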

SplitterConfig

diffbio.splitters.base.SplitterConfig dataclass

SplitterConfig(
    train_frac: float = 0.8,
    valid_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int | None = None,
)

Bases: StructuralConfig

Base configuration for splitters.

Frozen because splitters are non-parametric (StructuralModule).

Attributes:

Name Type Description
train_frac float

Fraction of data for training (default: 0.8)

valid_frac float

Fraction of data for validation (default: 0.1)

test_frac float

Fraction of data for testing (default: 0.1)

seed int | None

Random seed for reproducibility (optional)
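The docs state that configs are frozen and validated in __post_init__, but they do not list the checks. A minimal sketch, assuming validation means each fraction lies in [0, 1] and the three fractions sum to 1; ToySplitterConfig is a hypothetical stand-in that uses a plain frozen dataclass rather than StructuralConfig.

```python
import math
from dataclasses import FrozenInstanceError, dataclass
from typing import Optional


@dataclass(frozen=True)
class ToySplitterConfig:
    """Hypothetical frozen config with the same fields as SplitterConfig."""
    train_frac: float = 0.8
    valid_frac: float = 0.1
    test_frac: float = 0.1
    seed: Optional[int] = None

    def __post_init__(self):
        # Assumed checks: every fraction lies in [0, 1] and the three sum to 1.
        fracs = (self.train_frac, self.valid_frac, self.test_frac)
        if any(not 0.0 <= f <= 1.0 for f in fracs):
            raise ValueError(f"fractions must lie in [0, 1], got {fracs}")
        if not math.isclose(sum(fracs), 1.0, abs_tol=1e-6):
            raise ValueError(f"fractions must sum to 1, got {sum(fracs)}")


cfg = ToySplitterConfig(seed=42)
try:
    cfg.seed = 0  # frozen: assignment after construction is rejected
except FrozenInstanceError:
    print("config is immutable")
```

A tolerance-based comparison is used because fractions such as 0.7 + 0.15 + 0.15 are not always exactly 1.0 in floating point.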

SplitResult

diffbio.splitters.base.SplitResult

Bases: NamedTuple

Result of a dataset split operation.

Attributes:

Name Type Description
train_indices ndarray

Array of indices for training set

valid_indices ndarray

Array of indices for validation set

test_indices ndarray

Array of indices for test set

train_indices instance-attribute

train_indices: ndarray

valid_indices instance-attribute

valid_indices: ndarray

test_indices instance-attribute

test_indices: ndarray

train_size property

train_size: int

Return number of training samples.

valid_size property

valid_size: int

Return number of validation samples.

test_size property

test_size: int

Return number of test samples.
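Because SplitResult is a NamedTuple, it unpacks positionally into (train, valid, test), and the size properties can be read as the lengths of the index arrays. The sketch below is a hypothetical stand-in (ToySplitResult) that uses Python lists instead of arrays and assumes each size is simply len() of the matching index field.

```python
from typing import List, NamedTuple


class ToySplitResult(NamedTuple):
    """Hypothetical stand-in for SplitResult using Python lists for indices."""
    train_indices: List[int]
    valid_indices: List[int]
    test_indices: List[int]

    @property
    def train_size(self) -> int:
        return len(self.train_indices)

    @property
    def valid_size(self) -> int:
        return len(self.valid_indices)

    @property
    def test_size(self) -> int:
        return len(self.test_indices)


r = ToySplitResult([0, 1, 2, 3], [4], [5, 6])
train, valid, test = r  # NamedTuples unpack positionally
print(r.train_size, r.valid_size, r.test_size)  # 4 1 2
```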

Random Splitters

RandomSplitter

diffbio.splitters.random.RandomSplitter

RandomSplitter(
    config: RandomSplitterConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: SplitterModule

Simple random splitting using JAX RNG.

Uses JAX random permutation for reproducible splits. All data points are randomly assigned to train/valid/test sets according to the configured fractions.

Example
config = RandomSplitterConfig(train_frac=0.8, valid_frac=0.1, test_frac=0.1, seed=42)
splitter = RandomSplitter(config)
result = splitter.split(data_source)
print(f"Train size: {result.train_size}")

Parameters:

Name Type Description Default
config RandomSplitterConfig

Random splitter configuration

required
rngs Rngs | None

Random number generators

None
name str | None

Optional module name

None

split

split(data_source: DataSourceModule) -> SplitResult

Split data source randomly.

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required

Returns:

Type Description
SplitResult

SplitResult with randomly assigned train/valid/test indices
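The permutation-then-slice idea behind random splitting can be sketched without JAX. The docs say RandomSplitter uses a JAX random permutation (jax.random.permutation would be the JAX counterpart); this self-contained sketch uses Python's seeded random.Random instead. random_split is a hypothetical helper, and sending the rounding remainder to the test split is an assumption, not documented behaviour.

```python
import random


def random_split(n, train_frac=0.8, valid_frac=0.1, seed=0):
    """Shuffle indices 0..n-1, then cut them into train/valid/test slices.

    Any remainder from rounding down lands in the test split (an assumption;
    the library may distribute leftovers differently).
    """
    perm = list(range(n))
    random.Random(seed).shuffle(perm)  # seeded, so the split is reproducible
    n_train = int(n * train_frac)
    n_valid = int(n * valid_frac)
    return (
        sorted(perm[:n_train]),  # sorted only for readability
        sorted(perm[n_train:n_train + n_valid]),
        sorted(perm[n_train + n_valid:]),
    )


train, valid, test = random_split(103, seed=42)
print(len(train), len(valid), len(test))  # 82 10 11
```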

k_fold_split

k_fold_split(
    data_source: DataSourceModule, k: int = 5
) -> list[tuple[ndarray, ndarray]]

K-fold cross-validation split.

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required
k int

Number of folds

5

Returns:

Type Description
list[tuple[ndarray, ndarray]]

List of (train_indices, val_indices) tuples for each fold
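The k-fold contract (every index is validated exactly once, and each fold trains on the rest) can be sketched as follows. k_fold_indices is a hypothetical helper; shuffling before folding and giving the first n % k folds one extra element are assumptions about how a random k-fold split might be done, not the library's documented behaviour.

```python
import random


def k_fold_indices(n, k=5, seed=0):
    """Return (train_idx, val_idx) pairs; every index is validated exactly once."""
    if not 2 <= k <= n:
        raise ValueError(f"k must be in [2, n], got k={k}, n={n}")
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    # Spread the remainder: the first n % k folds get one extra element.
    base, extra = divmod(n, k)
    folds, start = [], 0
    for f in range(k):
        size = base + (1 if f < extra else 0)
        folds.append(perm[start:start + size])
        start += size
    result = []
    for f in range(k):
        # Training indices are the concatenation of all other folds.
        train = [i for g, fold in enumerate(folds) if g != f for i in fold]
        result.append((train, folds[f]))
    return result


for fold_idx, (tr, va) in enumerate(k_fold_indices(12, k=5, seed=42)):
    print(fold_idx, len(tr), len(va))
```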

RandomSplitterConfig

diffbio.splitters.random.RandomSplitterConfig dataclass

RandomSplitterConfig(
    train_frac: float = 0.8,
    valid_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int | None = None,
)

Bases: SplitterConfig

Configuration for random splitter.

Inherits all fields from SplitterConfig:
  • train_frac: Fraction of data for training (default: 0.8)
  • valid_frac: Fraction of data for validation (default: 0.1)
  • test_frac: Fraction of data for testing (default: 0.1)
  • seed: Random seed for reproducibility (optional)

train_frac class-attribute instance-attribute

train_frac: float = 0.8

valid_frac class-attribute instance-attribute

valid_frac: float = 0.1

test_frac class-attribute instance-attribute

test_frac: float = 0.1

seed class-attribute instance-attribute

seed: int | None = None

__post_init__

__post_init__()

Validate configuration after initialization.

StratifiedSplitter

diffbio.splitters.random.StratifiedSplitter

StratifiedSplitter(
    config: StratifiedSplitterConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: SplitterModule

Stratified splitting that preserves class distribution.

Ensures each split has approximately the same class distribution as the original dataset. Useful for imbalanced classification tasks.

Example
config = StratifiedSplitterConfig(seed=42, label_key="target")
splitter = StratifiedSplitter(config)
result = splitter.split(data_source)

Parameters:

Name Type Description Default
config StratifiedSplitterConfig

Stratified splitter configuration

required
rngs Rngs | None

Random number generators

None
name str | None

Optional module name

None

split

split(data_source: DataSourceModule) -> SplitResult

Split preserving class distribution.

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required

Returns:

Type Description
SplitResult

SplitResult with stratified train/valid/test indices
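Stratification can be sketched by splitting each class separately and concatenating the per-class pieces, so every split inherits the overall class proportions. stratified_split is a hypothetical helper; per-class shuffling with a seeded random.Random and rounding each class's share with round() are assumptions, not the library's documented procedure.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, train_frac=0.8, valid_frac=0.1, seed=0):
    """Split each class separately so every split keeps the class proportions."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)
    train, valid, test = [], [], []
    for y in sorted(by_class):  # fixed class order keeps the result reproducible
        members = by_class[y]
        rng.shuffle(members)
        n_train = round(len(members) * train_frac)
        n_valid = round(len(members) * valid_frac)
        train += members[:n_train]
        valid += members[n_train:n_train + n_valid]
        test += members[n_train + n_valid:]
    return sorted(train), sorted(valid), sorted(test)


labels = ["active"] * 10 + ["inactive"] * 90  # 10% positives
train, valid, test = stratified_split(labels, seed=42)
print(len(train), len(valid), len(test))  # 80 10 10
print(Counter(labels[i] for i in test))
```

With very small classes, rounding can leave a class absent from the validation or test split, which is worth checking on imbalanced data.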

StratifiedSplitterConfig

diffbio.splitters.random.StratifiedSplitterConfig dataclass

StratifiedSplitterConfig(
    train_frac: float = 0.8,
    valid_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int | None = None,
    label_key: str = "y",
)

Bases: SplitterConfig

Configuration for stratified splitter.

Attributes:

Name Type Description
label_key str

Key in data element containing labels (default: "y")

train_frac class-attribute instance-attribute

train_frac: float = 0.8

valid_frac class-attribute instance-attribute

valid_frac: float = 0.1

test_frac class-attribute instance-attribute

test_frac: float = 0.1

seed class-attribute instance-attribute

seed: int | None = None

__post_init__

__post_init__()

Validate configuration after initialization.

Molecular Splitters

ScaffoldSplitter

diffbio.splitters.molecular.ScaffoldSplitter

ScaffoldSplitter(
    config: ScaffoldSplitterConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: SplitterModule

Split molecules by Bemis-Murcko scaffold.

Inherits from SplitterModule (StructuralModule) because:

  • Non-parametric: scaffold extraction is deterministic
  • Frozen config: splitting strategy doesn't change
  • Domain-specific: requires RDKit and molecular knowledge

Ensures that molecules sharing a scaffold never end up in different splits, which reduces data leakage between structurally similar molecules. Scaffold splitting is a widely used convention in drug discovery benchmarks.

Requires RDKit to be installed.

Example
config = ScaffoldSplitterConfig(smiles_key="mol_smiles")
splitter = ScaffoldSplitter(config)
result = splitter.split(molecule_source)

References

Bemis, Guy W., and Mark A. Murcko. "The properties of known drugs. 1. Molecular frameworks." Journal of Medicinal Chemistry 39.15 (1996): 2887-2893.

Parameters:

Name Type Description Default
config ScaffoldSplitterConfig

Scaffold splitter configuration

required
rngs Rngs | None

Random number generators (unused for scaffold splitting)

None
name str | None

Optional module name

None

Raises:

Type Description
ImportError

If RDKit is not installed

split

split(data_source: DataSourceModule) -> SplitResult

Split by scaffold; the largest scaffold groups are assigned to train first.

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required

Returns:

Type Description
SplitResult

SplitResult with scaffold-based train/valid/test indices
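The "largest scaffolds go to train first" rule can be sketched as a greedy assignment of whole scaffold groups. scaffold_split is a hypothetical helper that takes precomputed scaffold keys (with RDKit these could come from rdkit.Chem.Scaffolds.MurckoScaffold.MurckoScaffoldSmiles); only the largest-first ordering is confirmed by the docs, while the overflow rule (a group goes to valid if it fits under the cumulative train+valid cutoff, otherwise to test) is an assumption based on common practice.

```python
from collections import defaultdict


def scaffold_split(scaffolds, train_frac=0.8, valid_frac=0.1):
    """Assign whole scaffold groups, largest first, to train, then valid, then test.

    `scaffolds[i]` is the scaffold key of molecule i (for example a canonical
    Bemis-Murcko scaffold SMILES computed beforehand).
    """
    groups = defaultdict(list)
    for idx, scaf in enumerate(scaffolds):
        groups[scaf].append(idx)
    # Largest groups first; ties broken by first index for determinism.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))
    n = len(scaffolds)
    train_cut = train_frac * n
    valid_cut = (train_frac + valid_frac) * n
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= train_cut:
            train += group
        elif len(train) + len(valid) + len(group) <= valid_cut:
            valid += group
        else:
            test += group
    return sorted(train), sorted(valid), sorted(test)


scaffolds = ["c1ccccc1"] * 5 + ["c1ccncc1"] * 3 + ["C1CCCCC1", "c1ccoc1"]
train, valid, test = scaffold_split(scaffolds)
print(train, valid, test)  # [0, 1, 2, 3, 4, 5, 6, 7] [8] [9]
```

Because groups are never split, the realized fractions can deviate from the configured ones when a few scaffolds dominate the dataset.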

ScaffoldSplitterConfig

diffbio.splitters.molecular.ScaffoldSplitterConfig dataclass

ScaffoldSplitterConfig(
    train_frac: float = 0.8,
    valid_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int | None = None,
    smiles_key: str = "smiles",
)

Bases: SplitterConfig

Configuration for scaffold splitter.

Attributes:

Name Type Description
smiles_key str

Key in data element containing SMILES string (default: "smiles")

train_frac class-attribute instance-attribute

train_frac: float = 0.8

valid_frac class-attribute instance-attribute

valid_frac: float = 0.1

test_frac class-attribute instance-attribute

test_frac: float = 0.1

seed class-attribute instance-attribute

seed: int | None = None

__post_init__

__post_init__()

Validate configuration after initialization.

TanimotoClusterSplitter

diffbio.splitters.molecular.TanimotoClusterSplitter

TanimotoClusterSplitter(
    config: TanimotoClusterSplitterConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: SplitterModule

Split by Tanimoto similarity clustering (Butina algorithm).

Groups similar molecules together using fingerprint similarity, then assigns clusters to train/valid/test to ensure structural diversity between splits.

Inherits from SplitterModule (StructuralModule) because:

  • Non-parametric: clustering is deterministic given fingerprints
  • Frozen config: splitting strategy doesn't change
  • Domain-specific: requires RDKit fingerprints

Requires RDKit to be installed.

Example
config = TanimotoClusterSplitterConfig(similarity_cutoff=0.6)
splitter = TanimotoClusterSplitter(config)
result = splitter.split(molecule_source)

References

Butina, Darko. "Unsupervised data base clustering based on Daylight's fingerprint and Tanimoto similarity." Journal of Chemical Information and Computer Sciences 39.4 (1999): 747-750.

Parameters:

Name Type Description Default
config TanimotoClusterSplitterConfig

Tanimoto cluster splitter configuration

required
rngs Rngs | None

Random number generators (unused)

None
name str | None

Optional module name

None

Raises:

Type Description
ImportError

If RDKit is not installed

split

split(data_source: DataSourceModule) -> SplitResult

Cluster by Tanimoto similarity and split.

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required

Returns:

Type Description
SplitResult

SplitResult with cluster-based train/valid/test indices
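Tanimoto similarity for bit fingerprints is |A ∩ B| / (|A| + |B| - |A ∩ B|), and Butina clustering picks centroids in order of decreasing neighbour count, claiming each centroid's still-unassigned neighbours. The sketch below represents fingerprints as Python sets of on-bit positions instead of RDKit bit vectors, and assumes similarity_cutoff means "neighbours if similarity >= cutoff" (RDKit's rdkit.ML.Cluster.Butina.ClusterData works with distances, 1 - similarity, instead). tanimoto and butina_clusters are hypothetical helpers.

```python
def tanimoto(a: set, b: set) -> float:
    """Tanimoto similarity of two fingerprints given as sets of on-bit positions."""
    if not a and not b:
        return 1.0
    inter = len(a & b)
    return inter / (len(a) + len(b) - inter)


def butina_clusters(fps, cutoff=0.6):
    """Butina clustering: molecules with similarity >= cutoff are neighbours."""
    n = len(fps)
    neighbours = [
        {j for j in range(n) if j != i and tanimoto(fps[i], fps[j]) >= cutoff}
        for i in range(n)
    ]
    # Visit candidate centroids in order of decreasing neighbour count.
    order = sorted(range(n), key=lambda i: (-len(neighbours[i]), i))
    assigned, clusters = set(), []
    for i in order:
        if i in assigned:
            continue
        # The centroid claims itself plus any neighbours not yet clustered.
        members = [i] + sorted(neighbours[i] - assigned)
        assigned.update(members)
        clusters.append(members)
    return clusters


fps = [
    {1, 2, 3, 4},
    {1, 2, 3, 5},
    {1, 2, 3, 4, 5},
    {10, 11, 12},
    {10, 11, 12, 13},
    {20, 21},
]
print(butina_clusters(fps, cutoff=0.6))  # [[0, 1, 2], [3, 4], [5]]
```

The docs do not say how TanimotoClusterSplitter assigns clusters to splits; one option is the same largest-first greedy rule used for scaffold groups.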

TanimotoClusterSplitterConfig

diffbio.splitters.molecular.TanimotoClusterSplitterConfig dataclass

TanimotoClusterSplitterConfig(
    train_frac: float = 0.8,
    valid_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int | None = None,
    smiles_key: str = "smiles",
    fingerprint_type: str = "morgan",
    fingerprint_radius: int = 2,
    fingerprint_bits: int = 2048,
    similarity_cutoff: float = 0.6,
)

Bases: SplitterConfig

Configuration for Tanimoto cluster splitter.

Attributes:

Name Type Description
smiles_key str

Key in data element containing SMILES string (default: "smiles")

fingerprint_type str

Type of fingerprint: "morgan" (default), "rdkit", or "maccs"

fingerprint_radius int

Radius for Morgan fingerprints (default: 2)

fingerprint_bits int

Number of bits for fingerprints (default: 2048)

similarity_cutoff float

Tanimoto similarity cutoff for clustering (default: 0.6)

train_frac class-attribute instance-attribute

train_frac: float = 0.8

valid_frac class-attribute instance-attribute

valid_frac: float = 0.1

test_frac class-attribute instance-attribute

test_frac: float = 0.1

seed class-attribute instance-attribute

seed: int | None = None

__post_init__

__post_init__()

Validate configuration after initialization.

Sequence Splitters

SequenceIdentitySplitter

diffbio.splitters.sequence.SequenceIdentitySplitter

SequenceIdentitySplitter(
    config: SequenceIdentitySplitterConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: SplitterModule

Split sequences by identity threshold.

Groups similar sequences by identity clustering, then assigns whole clusters to train/valid/test so that sequences from one cluster never appear in different splits. This reduces data leakage between closely related sequences.

Inherits from SplitterModule (StructuralModule) because:

  • Non-parametric: clustering is deterministic
  • Frozen config: splitting strategy doesn't change
  • Domain-specific: requires sequence comparison

The approach is similar to CD-HIT or MMseqs2 clustering.

Example
config = SequenceIdentitySplitterConfig(identity_threshold=0.3)
splitter = SequenceIdentitySplitter(config)
result = splitter.split(sequence_source)

References

Li, Weizhong, and Adam Godzik. "CD-HIT: a fast program for clustering and comparing large sets of protein or nucleotide sequences." Bioinformatics 22.13 (2006): 1658-1659.

Parameters:

Name Type Description Default
config SequenceIdentitySplitterConfig

Sequence identity splitter configuration

required
rngs Rngs | None

Random number generators (unused for identity splitting)

None
name str | None

Optional module name

None

split

split(data_source: DataSourceModule) -> SplitResult

Split by sequence identity clustering.

Clusters sequences by identity, then assigns clusters to train/valid/test splits. Largest clusters go to train first.

Parameters:

Name Type Description Default
data_source DataSourceModule

Datarax DataSourceModule to split

required

Returns:

Type Description
SplitResult

SplitResult with identity-based train/valid/test indices
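The docs do not define how alignment_method="simple" computes identity. The sketch below makes two labelled assumptions: identity is ungapped positional identity over the shorter sequence (a crude stand-in for alignment-based identity as in CD-HIT or MMseqs2), and clustering is greedy incremental in the CD-HIT style, where sequences are visited longest first and each joins the first cluster whose representative exceeds the threshold. identity and greedy_identity_clusters are hypothetical helpers.

```python
def identity(a: str, b: str) -> float:
    """Ungapped identity: matching positions over the shorter length (a simplification)."""
    m = min(len(a), len(b))
    if m == 0:
        return 0.0
    return sum(x == y for x, y in zip(a, b)) / m


def greedy_identity_clusters(seqs, threshold=0.3):
    """CD-HIT-style greedy clustering: longest sequences become representatives."""
    order = sorted(range(len(seqs)), key=lambda i: (-len(seqs[i]), i))
    reps, clusters = [], []
    for i in order:
        for c, r in enumerate(reps):
            # Join the first cluster whose representative is above the threshold.
            if identity(seqs[i], seqs[r]) > threshold:
                clusters[c].append(i)
                break
        else:
            # No representative is similar enough: start a new cluster.
            reps.append(i)
            clusters.append([i])
    return clusters


seqs = ["MKTAYIAKQR", "MKTAYIAKQL", "GGGSGGGS", "MKTAYIAK", "PPWPPWPPW"]
print(greedy_identity_clusters(seqs, threshold=0.3))  # [[0, 1, 3], [4], [2]]
```

Clusters produced this way would then be assigned to splits largest first, as the split() description above states.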

SequenceIdentitySplitterConfig

diffbio.splitters.sequence.SequenceIdentitySplitterConfig dataclass

SequenceIdentitySplitterConfig(
    train_frac: float = 0.8,
    valid_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int | None = None,
    sequence_key: str = "sequence",
    identity_threshold: float = 0.3,
    alignment_method: str = "simple",
)

Bases: SplitterConfig

Configuration for sequence identity splitter.

Attributes:

Name Type Description
sequence_key str

Key in data element containing sequence string (default: "sequence")

identity_threshold float

Identity threshold for clustering (default: 0.3). Sequences with identity > threshold are clustered together.

alignment_method str

Method for identity computation: "simple" (default) or "mmseqs2"

train_frac class-attribute instance-attribute

train_frac: float = 0.8

valid_frac class-attribute instance-attribute

valid_frac: float = 0.1

test_frac class-attribute instance-attribute

test_frac: float = 0.1

seed class-attribute instance-attribute

seed: int | None = None

__post_init__

__post_init__()

Validate configuration after initialization.

Usage Examples

Basic Random Splitting

from diffbio.splitters import RandomSplitter, RandomSplitterConfig
from flax import nnx

# Create splitter with 80/10/10 split
config = RandomSplitterConfig(
    train_frac=0.8,
    valid_frac=0.1,
    test_frac=0.1,
    seed=42,
)
splitter = RandomSplitter(config, rngs=nnx.Rngs(42))

# Split a data source
result = splitter.split(data_source)
print(f"Train: {result.train_size}")
print(f"Valid: {result.valid_size}")
print(f"Test: {result.test_size}")

Stratified Splitting

from diffbio.splitters import StratifiedSplitter, StratifiedSplitterConfig
from flax import nnx

# Preserve class distribution in splits
config = StratifiedSplitterConfig(
    train_frac=0.8,
    valid_frac=0.1,
    test_frac=0.1,
    label_key="label",  # Key in data element containing class label
    seed=42,
)
splitter = StratifiedSplitter(config, rngs=nnx.Rngs(42))
result = splitter.split(data_source)

Scaffold Splitting (Drug Discovery)

from diffbio.splitters import ScaffoldSplitter, ScaffoldSplitterConfig

# Split by Bemis-Murcko scaffold (requires RDKit)
config = ScaffoldSplitterConfig(
    smiles_key="smiles",
    train_frac=0.8,
    valid_frac=0.1,
    test_frac=0.1,
)
splitter = ScaffoldSplitter(config)
result = splitter.split(molecule_source)
# Molecules that share a scaffold always end up in the same split

Tanimoto Cluster Splitting

from diffbio.splitters import TanimotoClusterSplitter, TanimotoClusterSplitterConfig

# Split by fingerprint similarity clustering (requires RDKit)
config = TanimotoClusterSplitterConfig(
    smiles_key="smiles",
    fingerprint_type="morgan",  # or "rdkit", "maccs"
    fingerprint_radius=2,
    similarity_cutoff=0.6,
)
splitter = TanimotoClusterSplitter(config)
result = splitter.split(molecule_source)

Sequence Identity Splitting (Bioinformatics)

from diffbio.splitters import SequenceIdentitySplitter, SequenceIdentitySplitterConfig

# Split by sequence identity clustering
config = SequenceIdentitySplitterConfig(
    sequence_key="sequence",
    identity_threshold=0.3,  # Sequences with identity > 0.3 are clustered together
    alignment_method="simple",  # or "mmseqs2"
)
splitter = SequenceIdentitySplitter(config)
result = splitter.split(sequence_source)
# Sequences in the same identity cluster end up in the same split

Creating Split Data Sources

from diffbio.splitters import ScaffoldSplitter, ScaffoldSplitterConfig

# Create splitter
config = ScaffoldSplitterConfig(smiles_key="smiles")
splitter = ScaffoldSplitter(config)

# Create separate data sources for each split
train_source, valid_source, test_source = splitter.create_split_sources(
    data_source,
    lazy=True,  # Use IndexedViewSource for memory efficiency
)

# Use with Datarax samplers
from datarax.samplers import ShuffleSampler, ShuffleSamplerConfig
sampler_config = ShuffleSamplerConfig(batch_size=32)
train_sampler = ShuffleSampler(sampler_config, data_source=train_source)

K-Fold Cross-Validation

from diffbio.splitters import RandomSplitter, RandomSplitterConfig
from flax import nnx

config = RandomSplitterConfig(seed=42)
splitter = RandomSplitter(config, rngs=nnx.Rngs(42))

# Get 5-fold splits
folds = splitter.k_fold_split(data_source, k=5)

for fold_idx, (train_indices, val_indices) in enumerate(folds):
    print(f"Fold {fold_idx}: {len(train_indices)} train, {len(val_indices)} val")

Input Specifications

All Splitters

Parameter    Type              Description
data_source  DataSourceModule  Datarax data source to split

ScaffoldSplitter / TanimotoClusterSplitter

Data elements must contain:

Key         Type  Description
smiles_key  str   SMILES string for the molecule (key name set by the smiles_key config field; default "smiles")

SequenceIdentitySplitter

Data elements must contain:

Key           Type  Description
sequence_key  str   DNA, RNA, or protein sequence (key name set by the sequence_key config field; default "sequence")

StratifiedSplitter

Data elements must contain:

Key        Type        Description
label_key  int or str  Class label for stratification (key name set by the label_key config field; default "y")

Output Specifications

SplitResult

Field          Type         Description
train_indices  jnp.ndarray  Indices for the training set
valid_indices  jnp.ndarray  Indices for the validation set
test_indices   jnp.ndarray  Indices for the test set
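Before training, it can be worth confirming that a split is a proper partition of the dataset and, for scaffold or cluster splits, that no group crosses splits. check_split and shared_groups are hypothetical plain-Python helpers; the jnp.ndarray fields of a real SplitResult would first need converting, for example with .tolist().

```python
from itertools import combinations


def check_split(n, train, valid, test):
    """Raise ValueError unless the three index lists partition range(n)."""
    parts = {"train": list(train), "valid": list(valid), "test": list(test)}
    for name, idx in parts.items():
        if len(set(idx)) != len(idx):
            raise ValueError(f"{name} contains duplicate indices")
        if any(not 0 <= i < n for i in idx):
            raise ValueError(f"{name} contains out-of-range indices")
    for (a, ia), (b, ib) in combinations(parts.items(), 2):
        overlap = set(ia) & set(ib)
        if overlap:
            raise ValueError(f"{a} and {b} share indices {sorted(overlap)}")
    # Disjoint, duplicate-free, in-range lists whose sizes sum to n cover range(n).
    covered = sum(len(idx) for idx in parts.values())
    if covered != n:
        raise ValueError(f"splits cover {covered} of {n} samples")


def shared_groups(groups, train, test):
    """Return group keys (e.g. scaffolds or cluster ids) present in both splits."""
    return {groups[i] for i in train} & {groups[i] for i in test}


check_split(5, [0, 1, 2], [3], [4])  # passes silently
print(shared_groups(["a", "a", "b"], [0], [1]))  # {'a'}
```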

Splitter Comparison

Splitter                  Use case               Domain     Dependencies
RandomSplitter            General purpose        Any        None
StratifiedSplitter        Class-imbalanced data  Any        None
ScaffoldSplitter          Drug discovery         Chemistry  RDKit
TanimotoClusterSplitter   Molecular similarity   Chemistry  RDKit
SequenceIdentitySplitter  Genomics/Proteomics    Biology    None