Training Utilities API¤

Training utilities for differentiable bioinformatics pipelines.

Trainer¤

diffbio.utils.training.Trainer ¤

Trainer(pipeline: OperatorModule, config: TrainingConfig)

Training loop for DiffBio pipelines using Flax NNX patterns.

This class handles the training loop using NNX's stateful approach:

  • Uses nnx.Optimizer for automatic parameter updates
  • Uses @nnx.jit for JIT compilation with state management
  • Supports gradient clipping and metric logging

Example
pipeline = create_variant_calling_pipeline(reference_length=100)
trainer = Trainer(pipeline, TrainingConfig(learning_rate=1e-3))
# Define loss function
def loss_fn(predictions, targets):
    return cross_entropy_loss(
        predictions["logits"],
        targets["labels"],
    )
# Train
trainer.train(data_iterator_fn, loss_fn)
trained_pipeline = trainer.pipeline

Parameters:

Name Type Description Default
pipeline OperatorModule

Pipeline to train

required
config TrainingConfig

Training configuration

required

train ¤

train(
    data_iterator_fn: Callable, loss_fn: Callable
) -> None

Run full training loop.

After training, the pipeline is updated in-place with trained parameters. Access via trainer.pipeline.

Parameters:

Name Type Description Default
data_iterator_fn Callable

Function that returns a fresh data iterator

required
loss_fn Callable

Loss function, called as loss_fn(predictions, targets) and returning a scalar loss

required
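
train takes a function rather than an iterator, presumably because a Python iterator is exhausted after one pass and each epoch needs a fresh one. The sketch below illustrates this in plain Python. run_epochs is a hypothetical stand-in for the training loop, not part of diffbio.

```python
from typing import Callable, Iterator


def run_epochs(data_iterator_fn: Callable[[], Iterator[int]], num_epochs: int) -> list:
    """Hypothetical stand-in for a training loop: counts batches seen per epoch."""
    batches_per_epoch = []
    for _ in range(num_epochs):
        # Calling the factory is what gives each epoch its own iterator.
        batches_per_epoch.append(sum(1 for _batch in data_iterator_fn()))
    return batches_per_epoch


data = [10, 20, 30]

# A real factory: every epoch sees the full dataset.
fresh_counts = run_epochs(lambda: iter(data), num_epochs=3)

# A "factory" that keeps returning the same iterator: later epochs see nothing.
shared = iter(data)
stale_counts = run_epochs(lambda: shared, num_epochs=3)
```

This is why the complete example on this page passes lambda: data_iterator(inputs, targets) rather than the result of calling data_iterator once.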

train_epoch ¤

train_epoch(
    data_iterator: Iterator[
        tuple[dict[str, Array], dict[str, Array]]
    ],
    loss_fn: Callable,
) -> dict[str, float]

Train for one epoch.

Parameters:

Name Type Description Default
data_iterator Iterator[tuple[dict[str, Array], dict[str, Array]]]

Iterator yielding (batch_data, targets) tuples

required
loss_fn Callable

Loss function, called as loss_fn(predictions, targets) and returning a scalar loss

required

Returns:

Type Description
dict[str, float]

Dict of epoch metrics
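
This page does not list the keys of the returned metrics dict. The sketch below shows one way per-batch losses from (batch_data, targets) tuples could be reduced to a dict[str, float]. summarize_epoch, step_fn, and the keys "loss" and "num_batches" are illustrative assumptions, not the Trainer's actual implementation.

```python
from typing import Callable, Iterable


def summarize_epoch(batches: Iterable[tuple], step_fn: Callable[[dict, dict], float]) -> dict:
    """Hypothetical epoch reducer: runs step_fn per batch and averages the losses."""
    losses = [float(step_fn(batch_data, targets)) for batch_data, targets in batches]
    if not losses:
        # An empty iterator has no meaningful mean loss.
        return {"loss": float("nan"), "num_batches": 0.0}
    return {"loss": sum(losses) / len(losses), "num_batches": float(len(losses))}


# Toy data in the documented (batch_data, targets) shape.
batches = [
    ({"reads": [0, 1]}, {"labels": [0]}),
    ({"reads": [2, 3]}, {"labels": [1]}),
]

# Toy step function: pretends the loss is the first label plus 0.5.
metrics = summarize_epoch(iter(batches), lambda batch_data, targets: targets["labels"][0] + 0.5)
```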

Configuration¤

TrainingConfig¤

diffbio.utils.training.TrainingConfig dataclass ¤

TrainingConfig(
    learning_rate: float = 0.001,
    num_epochs: int = 100,
    log_every: int = 10,
    grad_clip_norm: float | None = 1.0,
)

Configuration for training loop.

Attributes:

Name Type Description
learning_rate float

Learning rate for optimizer

num_epochs int

Number of training epochs

log_every int

Log metrics every N steps

grad_clip_norm float | None

Maximum gradient norm (None to disable)
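
Because TrainingConfig is a dataclass, the standard dataclasses helpers should apply to it. The sketch below derives a variant configuration with dataclasses.replace. ExampleConfig is a local stand-in that mirrors the fields documented above so the snippet runs without diffbio installed. With the real class, substitute TrainingConfig.

```python
from dataclasses import asdict, dataclass, replace
from typing import Optional


@dataclass
class ExampleConfig:
    """Local stand-in mirroring the documented TrainingConfig fields and defaults."""
    learning_rate: float = 0.001
    num_epochs: int = 100
    log_every: int = 10
    grad_clip_norm: Optional[float] = 1.0


base = ExampleConfig()

# Derive a short run with clipping disabled; untouched fields keep their defaults.
no_clip = replace(base, num_epochs=20, grad_clip_norm=None)

# asdict gives a plain dict, handy for logging the run's hyperparameters.
logged = asdict(no_clip)
```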

TrainingState¤

diffbio.utils.training.TrainingState dataclass ¤

TrainingState(
    step: int = 0,
    epoch: int = 0,
    loss_history: list[float] | None = None,
    best_loss: float = float("inf"),
)

State maintained during training.

Attributes:

Name Type Description
step int

Current training step

epoch int

Current epoch

loss_history list[float] | None

List of loss values

best_loss float

Best loss seen so far
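
loss_history defaults to None rather than an empty list, which avoids sharing one mutable default across instances. The sketch below shows how such a state could be updated after each step. ExampleState and record_loss are hypothetical stand-ins. This page does not state whether the Trainer initializes loss_history itself.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleState:
    """Local stand-in mirroring the documented TrainingState fields."""
    step: int = 0
    epoch: int = 0
    loss_history: Optional[list] = None
    best_loss: float = float("inf")


def record_loss(state: ExampleState, loss: float) -> ExampleState:
    """Hypothetical helper: append a loss, track the best value, advance the step."""
    if state.loss_history is None:
        # Create the list lazily so each state owns its own history.
        state.loss_history = []
    state.loss_history.append(loss)
    state.best_loss = min(state.best_loss, loss)
    state.step += 1
    return state


state = ExampleState()
for value in (0.9, 0.4, 0.6):
    record_loss(state, value)
```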

Loss Functions¤

cross_entropy_loss¤

diffbio.utils.training.cross_entropy_loss ¤

cross_entropy_loss(
    logits: Float[Array, "... num_classes"],
    labels: Float[Array, ...],
    num_classes: int = 3,
) -> Float[Array, ""]

Compute cross-entropy loss for variant classification.

Parameters:

Name Type Description Default
logits Float[Array, '... num_classes']

Raw model predictions

required
labels Float[Array, ...]

Integer class labels (with the default 3 classes: 0=ref, 1=snp, 2=indel)

required
num_classes int

Number of classes

3

Returns:

Type Description
Float[Array, '']

Scalar loss value
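
For reference, the sketch below computes standard softmax cross-entropy in plain Python. It assumes mean reduction over examples, which this page does not confirm, and it is not the library's implementation (that operates on JAX arrays). Both function names are illustrative.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def mean_cross_entropy(batch_logits, labels):
    """Average the per-example losses (assumed reduction)."""
    losses = [softmax_cross_entropy(row, y) for row, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# Uniform logits over 3 classes give a loss of log(3) for any label.
uniform = softmax_cross_entropy([0.0, 0.0, 0.0], label=1)

# A confident, correct prediction gives a loss near zero.
confident = softmax_cross_entropy([10.0, 0.0, 0.0], label=0)

# The same logits with the wrong label give a large loss.
wrong = softmax_cross_entropy([10.0, 0.0, 0.0], label=1)

batch_loss = mean_cross_entropy([[10.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0, 2])
```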

Optimizer Utilities¤

create_optax_optimizer¤

diffbio.utils.training.create_optax_optimizer ¤

create_optax_optimizer(
    config: TrainingConfig,
) -> GradientTransformation

Create optax optimizer with optional gradient clipping.

Parameters:

Name Type Description Default
config TrainingConfig

Training configuration

required

Returns:

Type Description
GradientTransformation

Optax optimizer
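
The sketch below illustrates global-norm gradient clipping in plain Python, the behaviour provided by optax.clip_by_global_norm. That grad_clip_norm is applied this way is an assumption, as is the choice of base optimizer, which this page does not name. clip_global_norm_sketch is a hypothetical helper.

```python
import math


def clip_global_norm_sketch(grads, max_norm):
    """Rescale all gradients together so their joint L2 norm is at most max_norm."""
    if max_norm is None:
        # Mirrors grad_clip_norm=None: clipping disabled.
        return grads
    global_norm = math.sqrt(sum(g * g for group in grads.values() for g in group))
    if global_norm <= max_norm:
        return grads
    scale = max_norm / global_norm
    return {name: [g * scale for g in group] for name, group in grads.items()}


# Two parameter groups whose joint norm is sqrt(3^2 + 4^2) = 5.
grads = {"w": [3.0, 0.0], "b": [4.0]}

clipped = clip_global_norm_sketch(grads, max_norm=1.0)
clipped_norm = math.sqrt(sum(g * g for group in clipped.values() for g in group))
```

Clipping by the global norm preserves the direction of the overall gradient, unlike clipping each parameter group independently.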

Data Utilities¤

create_synthetic_training_data¤

diffbio.utils.training.create_synthetic_training_data ¤

create_synthetic_training_data(
    num_samples: int = 100,
    num_reads: int = 10,
    read_length: int = 50,
    reference_length: int = 100,
    variant_rate: float = 0.1,
    seed: int = 42,
) -> tuple[list[dict[str, Array]], list[dict[str, Array]]]

Create synthetic training data for variant calling.

Generates reads with simulated variants for training.

Parameters:

Name Type Description Default
num_samples int

Number of samples to generate

100
num_reads int

Number of reads per sample

10
read_length int

Length of each read

50
reference_length int

Length of reference sequence

100
variant_rate float

Probability of variant at each position

0.1
seed int

Random seed

42

Returns:

Type Description
tuple[list[dict[str, Array]], list[dict[str, Array]]]

Tuple of (inputs, targets) where:
  • inputs: List of dicts with reads, positions, quality
  • targets: List of dicts with labels (0=ref, 1=snp, 2=indel)

data_iterator¤

diffbio.utils.training.data_iterator ¤

data_iterator(
    inputs: list[dict[str, Array]],
    targets: list[dict[str, Array]],
    batch_size: int = 1,
) -> Iterator[tuple[dict[str, Array], dict[str, Array]]]

Create an iterator over training data.

Parameters:

Name Type Description Default
inputs list[dict[str, Array]]

List of input dicts

required
targets list[dict[str, Array]]

List of target dicts

required
batch_size int

Batch size (currently only supports 1)

1

Yields:

Type Description
tuple[dict[str, Array], dict[str, Array]]

Tuples of (batch_data, targets)
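
The sketch below shows the yielded shape with a generator that pairs each input dict with its target dict, one sample at a time. paired_iterator is a hypothetical stand-in, not the library's data_iterator, and the length check is an assumption about reasonable behaviour.

```python
from typing import Iterator


def paired_iterator(inputs: list, targets: list) -> Iterator[tuple]:
    """Hypothetical stand-in: yield (batch_data, targets) one sample at a time."""
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    for sample, target in zip(inputs, targets):
        yield sample, target


# Toy data; the real dicts hold arrays under keys such as reads and labels.
inputs = [{"reads": [0, 1]}, {"reads": [2, 3]}]
targets = [{"labels": [0]}, {"labels": [1]}]

pairs = list(paired_iterator(inputs, targets))
first_batch, first_targets = pairs[0]
```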

Usage Examples¤

Complete Training Example¤

from diffbio.pipelines import create_variant_calling_pipeline
from diffbio.utils.training import (
    Trainer,
    TrainingConfig,
    cross_entropy_loss,
    create_synthetic_training_data,
    data_iterator,
)

# Create pipeline
pipeline = create_variant_calling_pipeline(reference_length=100)

# Create trainer
config = TrainingConfig(
    learning_rate=1e-3,
    num_epochs=50,
    log_every=10,
)
trainer = Trainer(pipeline, config)

# Generate data
inputs, targets = create_synthetic_training_data(
    num_samples=100,
    reference_length=100,
)

# Define loss
def loss_fn(predictions, targets):
    return cross_entropy_loss(
        predictions["logits"],
        targets["labels"],
    )

# Train
trainer.train(
    data_iterator_fn=lambda: data_iterator(inputs, targets),
    loss_fn=loss_fn,
)

# Get trained model
trained = trainer.pipeline

Custom Training Loop¤

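import optax
from flax import nnx

from diffbio.utils.training import cross_entropy_loss

# `pipeline` is the OperatorModule created in the previous example.
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(pipeline, nnx.Param))

# nnx.jit (rather than jax.jit) propagates the in-place parameter update
# made inside the step back to the caller's pipeline object.
@nnx.jit
def train_step(pipeline, opt_state, batch, targets):
    def loss_fn(model):
        result, _, _ = model.apply(batch, {}, None)
        return cross_entropy_loss(result["logits"], targets["labels"])

    # nnx.value_and_grad differentiates with respect to the nnx.Param state.
    loss, grads = nnx.value_and_grad(loss_fn)(pipeline)
    params = nnx.state(pipeline, nnx.Param)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    nnx.update(pipeline, optax.apply_updates(params, updates))
    return loss, opt_state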

Module Exports¤

from diffbio.utils.training import (
    # Core
    Trainer,
    TrainingConfig,
    TrainingState,

    # Loss
    cross_entropy_loss,

    # Optimizer
    create_optax_optimizer,

    # Data
    create_synthetic_training_data,
    data_iterator,
)