Training Utilities API
Training utilities for differentiable bioinformatics pipelines.
Trainer
diffbio.utils.training.Trainer
Trainer(pipeline: OperatorModule, config: TrainingConfig)
Training loop for DiffBio pipelines using Flax NNX patterns.
This class handles the training loop using NNX's stateful approach:
- Uses nnx.Optimizer for automatic parameter updates
- Uses @nnx.jit for JIT compilation with state management
- Supports gradient clipping and metric logging
Example
pipeline = create_variant_calling_pipeline(reference_length=100)
trainer = Trainer(pipeline, TrainingConfig(learning_rate=1e-3))

# Define loss function
def loss_fn(predictions, targets):
    return cross_entropy_loss(
        predictions["logits"],
        targets["labels"],
    )

# Train
trainer.train(data_iterator_fn, loss_fn)
trained_pipeline = trainer.pipeline
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pipeline` | `OperatorModule` | Pipeline to train | required |
| `config` | `TrainingConfig` | Training configuration | required |
train
Run full training loop.
Training updates the pipeline in place. Access the trained pipeline via trainer.pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data_iterator_fn` | `Callable` | Function that returns a fresh data iterator | required |
| `loss_fn` | `Callable` | Loss function | required |
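Why `train` takes a function that returns an iterator, rather than an iterator itself, can be seen in a small library-independent sketch. `run_epochs` is a hypothetical helper, not part of diffbio; it only counts the batches each epoch would see.

```python
from typing import Callable, Iterator


def run_epochs(data_iterator_fn: Callable[[], Iterator], num_epochs: int) -> list:
    """Hypothetical helper: count the batches consumed in each epoch."""
    counts = []
    for _ in range(num_epochs):
        # Ask the factory for an iterator at the start of every epoch.
        counts.append(sum(1 for _ in data_iterator_fn()))
    return counts


batches = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]

# A factory that builds a new iterator per call: every epoch sees all batches.
fresh = run_epochs(lambda: iter(batches), num_epochs=3)

# A factory that hands back the same iterator: it is exhausted after epoch one.
shared = iter(batches)
stale = run_epochs(lambda: shared, num_epochs=3)

print(fresh)  # [3, 3, 3]
print(stale)  # [3, 0, 0]
```

With the utilities on this page, the corresponding factory is `lambda: data_iterator(inputs, targets)`.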
train_epoch
train_epoch(
    data_iterator: Iterator[
        tuple[dict[str, Array], dict[str, Array]]
    ],
    loss_fn: Callable,
) -> dict[str, float]
Train for one epoch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data_iterator` | `Iterator[tuple[dict[str, Array], dict[str, Array]]]` | Iterator yielding (batch_data, targets) tuples | required |
| `loss_fn` | `Callable` | Loss function | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, float]` | Dict of epoch metrics |
Configuration
TrainingConfig
diffbio.utils.training.TrainingConfig
dataclass
TrainingConfig(
    learning_rate: float = 0.001,
    num_epochs: int = 100,
    log_every: int = 10,
    grad_clip_norm: float | None = 1.0,
)
Configuration for training loop.
Attributes:
| Name | Type | Description |
|---|---|---|
| `learning_rate` | `float` | Learning rate for optimizer |
| `num_epochs` | `int` | Number of training epochs |
| `log_every` | `int` | Log metrics every N steps |
| `grad_clip_norm` | `float \| None` | Maximum gradient norm (None to disable) |
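Since `TrainingConfig` is a dataclass, variants of a base configuration can be derived with `dataclasses.replace`. The sketch below uses a stand-in class, `TrainingConfigSketch`, that mirrors the documented fields so that it runs without diffbio installed. It is an illustration, not the library's definition, and `Optional[float]` stands in for the documented `float | None`.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class TrainingConfigSketch:
    """Stand-in mirroring the documented TrainingConfig fields and defaults."""

    learning_rate: float = 0.001
    num_epochs: int = 100
    log_every: int = 10
    grad_clip_norm: Optional[float] = 1.0


base = TrainingConfigSketch()

# Short smoke-test run that logs every step; other fields keep their defaults.
quick = replace(base, num_epochs=5, log_every=1)

# Same schedule as the base config, but with gradient clipping disabled.
unclipped = replace(base, grad_clip_norm=None)
```

`replace` returns a new object each time, so `base` is left unchanged.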
TrainingState
diffbio.utils.training.TrainingState
dataclass
TrainingState(
    step: int = 0,
    epoch: int = 0,
    loss_history: list[float] | None = None,
    best_loss: float = float("inf"),
)
State maintained during training.
Attributes:
| Name | Type | Description |
|---|---|---|
| `step` | `int` | Current training step |
| `epoch` | `int` | Current epoch |
| `loss_history` | `list[float] \| None` | List of loss values |
| `best_loss` | `float` | Best loss seen so far |
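`loss_history` defaults to `None` rather than an empty list, presumably because dataclasses reject a bare mutable list as a field default. The sketch below shows one way such a state record could be updated after each step. `TrainingStateSketch` mirrors the documented fields, and `record_loss` is a hypothetical helper; how the library itself fills these fields is not documented here.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TrainingStateSketch:
    """Stand-in mirroring the documented TrainingState fields and defaults."""

    step: int = 0
    epoch: int = 0
    loss_history: Optional[List[float]] = None
    best_loss: float = float("inf")


def record_loss(state: TrainingStateSketch, loss: float) -> None:
    """Hypothetical helper: fold one step's loss into the state."""
    if state.loss_history is None:
        # Create the list lazily, since the field starts out as None.
        state.loss_history = []
    state.loss_history.append(loss)
    state.best_loss = min(state.best_loss, loss)
    state.step += 1


state = TrainingStateSketch()
for step_loss in [0.9, 0.5, 0.7]:
    record_loss(state, step_loss)

print(state.step, state.best_loss)  # 3 0.5
```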
Loss Functions
cross_entropy_loss
diffbio.utils.training.cross_entropy_loss
cross_entropy_loss(
    logits: Float[Array, "... num_classes"],
    labels: Float[Array, ...],
    num_classes: int = 3,
) -> Float[Array, ""]
Compute cross-entropy loss for variant classification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Float[Array, '... num_classes']` | Raw model predictions | required |
| `labels` | `Float[Array, ...]` | Integer class labels | required |
| `num_classes` | `int` | Number of classes | `3` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar loss value |
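As a rough illustration of what a loss with this signature computes, the sketch below one-hot encodes integer labels, applies a log-softmax to the logits, and averages the negative log-likelihood. The mean reduction and the name `cross_entropy_sketch` are assumptions; the library's implementation may differ in its details.

```python
import jax
import jax.numpy as jnp


def cross_entropy_sketch(logits, labels, num_classes=3):
    """Softmax cross-entropy averaged over all leading dimensions (assumed)."""
    one_hot = jax.nn.one_hot(labels, num_classes)     # (..., num_classes)
    log_probs = jax.nn.log_softmax(logits, axis=-1)   # (..., num_classes)
    per_example = -jnp.sum(one_hot * log_probs, axis=-1)
    return jnp.mean(per_example)                      # scalar


# Uniform logits give each of the 3 classes probability 1/3,
# so the loss is log(3) regardless of the labels.
logits = jnp.zeros((4, 3))
labels = jnp.array([0, 1, 2, 1])
loss = cross_entropy_sketch(logits, labels)
```

A quick sanity check like the uniform-logits case is a cheap way to confirm that a loss is wired up with the expected number of classes.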
Optimizer Utilities
create_optax_optimizer
diffbio.utils.training.create_optax_optimizer
create_optax_optimizer(
    config: TrainingConfig,
) -> GradientTransformation
Create optax optimizer with optional gradient clipping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `TrainingConfig` | Training configuration | required |
Returns:
| Type | Description |
|---|---|
| `GradientTransformation` | Optax optimizer |
Data Utilities
create_synthetic_training_data
diffbio.utils.training.create_synthetic_training_data
create_synthetic_training_data(
    num_samples: int = 100,
    num_reads: int = 10,
    read_length: int = 50,
    reference_length: int = 100,
    variant_rate: float = 0.1,
    seed: int = 42,
) -> tuple[list[dict[str, Array]], list[dict[str, Array]]]
Create synthetic training data for variant calling.
Generates reads with simulated variants for training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_samples` | `int` | Number of samples to generate | `100` |
| `num_reads` | `int` | Number of reads per sample | `10` |
| `read_length` | `int` | Length of each read | `50` |
| `reference_length` | `int` | Length of reference sequence | `100` |
| `variant_rate` | `float` | Probability of variant at each position | `0.1` |
| `seed` | `int` | Random seed | `42` |
Returns:
| Type | Description |
|---|---|
| `tuple[list[dict[str, Array]], list[dict[str, Array]]]` | Tuple of (inputs, targets) |
data_iterator
diffbio.utils.training.data_iterator
data_iterator(
    inputs: list[dict[str, Array]],
    targets: list[dict[str, Array]],
    batch_size: int = 1,
) -> Iterator[tuple[dict[str, Array], dict[str, Array]]]
Create an iterator over training data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `list[dict[str, Array]]` | List of input dicts | required |
| `targets` | `list[dict[str, Array]]` | List of target dicts | required |
| `batch_size` | `int` | Batch size (only 1 is currently supported) | `1` |
Yields:
| Type | Description |
|---|---|
| `tuple[dict[str, Array], dict[str, Array]]` | Tuples of (batch_data, targets) |
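A minimal stand-in with the same call shape, assuming samples are paired one at a time in order, might look like the sketch below. `paired_iterator` is hypothetical; rejecting batch sizes other than 1 and mismatched list lengths are assumptions about sensible behavior, not documented properties of `data_iterator`. The `"reads"` key is a placeholder.

```python
from typing import Any, Dict, Iterator, List, Tuple

Sample = Dict[str, Any]


def paired_iterator(
    inputs: List[Sample], targets: List[Sample], batch_size: int = 1
) -> Iterator[Tuple[Sample, Sample]]:
    """Hypothetical stand-in: yield (input, target) pairs one sample at a time."""
    if batch_size != 1:
        raise ValueError("this sketch only supports batch_size=1")
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    return zip(inputs, targets)


# Placeholder sample dicts; real samples would hold arrays.
inputs = [{"reads": [0, 1]}, {"reads": [2, 3]}]
targets = [{"labels": [0]}, {"labels": [1]}]

pairs = list(paired_iterator(inputs, targets))
```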
Usage Examples
Complete Training Example
from diffbio.pipelines import create_variant_calling_pipeline
from diffbio.utils.training import (
    Trainer,
    TrainingConfig,
    cross_entropy_loss,
    create_synthetic_training_data,
    data_iterator,
)

# Create pipeline
pipeline = create_variant_calling_pipeline(reference_length=100)

# Create trainer
config = TrainingConfig(
    learning_rate=1e-3,
    num_epochs=50,
    log_every=10,
)
trainer = Trainer(pipeline, config)

# Generate data
inputs, targets = create_synthetic_training_data(
    num_samples=100,
    reference_length=100,
)

# Define loss
def loss_fn(predictions, targets):
    return cross_entropy_loss(
        predictions["logits"],
        targets["labels"],
    )

# Train
trainer.train(
    data_iterator_fn=lambda: data_iterator(inputs, targets),
    loss_fn=loss_fn,
)

# Get trained model
trained = trainer.pipeline
Custom Training Loop
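import optax
from flax import nnx

from diffbio.utils.training import cross_entropy_loss

# `pipeline` is an existing pipeline, e.g. from create_variant_calling_pipeline
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(pipeline, nnx.Param))

# nnx.jit (rather than jax.jit) accepts NNX modules as arguments and
# propagates the in-place nnx.update back to the caller's pipeline.
@nnx.jit
def train_step(pipeline, opt_state, batch, targets):
    def loss_fn(model):
        result, _, _ = model.apply(batch, {}, None)
        return cross_entropy_loss(result["logits"], targets["labels"])

    # nnx.value_and_grad differentiates with respect to the module's nnx.Param state
    loss, grads = nnx.value_and_grad(loss_fn)(pipeline)
    params = nnx.state(pipeline, nnx.Param)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    nnx.update(pipeline, optax.apply_updates(params, updates))
    return loss, opt_state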