Single-Cell Losses API
Loss functions for training single-cell analysis models.
BatchMixingLoss
`diffbio.losses.singlecell_losses.BatchMixingLoss`
BatchMixingLoss(
n_neighbors: int = 15,
n_batches: int = 3,
temperature: float = 1.0,
*,
rngs: Rngs | None = None,
)
Bases: Module
Loss function to maximize batch mixing in latent space.
Computes how well batches are mixed in the embedding space by measuring the entropy of batch labels among k-nearest neighbors for each cell. Higher entropy indicates better mixing.
The loss encourages the model to learn representations where cells from different batches are interleaved, reducing batch effects.
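The following minimal JAX sketch makes the entropy-of-neighbors idea concrete. It is not diffbio's implementation. The function name `batch_mixing_loss_sketch`, the use of squared Euclidean distances, hard top-k neighbor selection with `jax.lax.top_k`, and softmax weighting of those neighbors by negative distance divided by `temperature` are all assumptions chosen to match the parameter descriptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy


def batch_mixing_loss_sketch(
    embeddings: jax.Array,
    batch_labels: jax.Array,
    n_neighbors: int = 15,
    n_batches: int = 3,
    temperature: float = 1.0,
) -> jax.Array:
    """Hypothetical illustration: negative mean entropy of soft k-NN batch mixes."""
    n_cells = embeddings.shape[0]
    # Pairwise squared Euclidean distances, shape (n_cells, n_cells)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    sq_dist = jnp.sum(diff**2, axis=-1)
    # Exclude each cell from its own neighborhood
    sq_dist = jnp.where(jnp.eye(n_cells, dtype=bool), jnp.inf, sq_dist)
    # k nearest neighbors via top_k on negated distances (k must be static)
    neg_dist, nbr_idx = jax.lax.top_k(-sq_dist, n_neighbors)
    # Soft neighbor weights (assumed): softmax of negative distance / temperature
    weights = jax.nn.softmax(neg_dist / temperature, axis=-1)
    # One-hot batch labels of each neighbor, shape (n_cells, k, n_batches)
    nbr_onehot = jax.nn.one_hot(batch_labels[nbr_idx], n_batches)
    # Weighted batch distribution per neighborhood, shape (n_cells, n_batches)
    batch_dist = jnp.einsum("ik,ikb->ib", weights, nbr_onehot)
    # Shannon entropy per cell; xlogy treats 0 * log(0) as 0
    entropy = -jnp.sum(xlogy(batch_dist, batch_dist), axis=-1)
    # Negative mean entropy: lower values mean better mixing
    return -jnp.mean(entropy)
```

In this sketch, gradients reach the embeddings only through the softmax weights (the top-k selection itself is piecewise constant). `n_neighbors` and `n_batches` must be Python integers because `jax.lax.top_k` and `jax.nn.one_hot` need static sizes, which matches the JIT note on `n_batches`.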
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_neighbors` | `int` | Number of nearest neighbors to consider. | `15` |
| `n_batches` | `int` | Number of batches. Must be static (a Python `int`) for JIT compatibility. | `3` |
| `temperature` | `float` | Temperature for the softmax used in soft neighbor selection over distances. | `1.0` |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. Not used; accepted for API consistency. | `None` |
__call__
__call__(
embeddings: Float[Array, "n_cells latent_dim"],
batch_labels: Int[Array, "n_cells"],
) -> Float[Array, ""]
Compute batch mixing loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `embeddings` | `Float[Array, 'n_cells latent_dim']` | Cell embeddings in latent space. | required |
| `batch_labels` | `Int[Array, 'n_cells']` | Integer batch label for each cell. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Negative mean entropy of the batch distribution in each cell's neighborhood (scalar). Lower loss means better mixing. |
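In symbols, the returned value can be read as follows. This is a sketch, not taken from the diffbio source: it assumes natural logarithms and soft neighbor weights $w_{ij}$ that are non-negative and sum to 1 over the $k$ nearest neighbors $\mathcal{N}_k(i)$ of cell $i$ (for example, a softmax of negative distances divided by `temperature`). Here $y_j$ is the batch label of cell $j$, $B$ is `n_batches`, and $N$ is the number of cells.

$$
p_{ib} = \sum_{j \in \mathcal{N}_k(i)} w_{ij}\,\mathbb{1}[y_j = b], \qquad
H_i = -\sum_{b=1}^{B} p_{ib} \log p_{ib}, \qquad
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} H_i
$$

Because $0 \le H_i \le \log B$, the loss lies in $[-\log B, 0]$: it reaches $-\log B$ only when every neighborhood contains all batches in equal proportion, and $0$ when every neighborhood is drawn from a single batch.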
ClusteringCompactnessLoss
`diffbio.losses.singlecell_losses.ClusteringCompactnessLoss`
ClusteringCompactnessLoss(
separation_weight: float = 1.0,
min_separation: float = 1.0,
*,
rngs: Rngs | None = None,
)
Bases: Module
Loss function to encourage compact and well-separated clusters.
Combines two components:

1. Compactness: minimize within-cluster variance.
2. Separation: maximize between-cluster distances.
Works with soft cluster assignments for end-to-end differentiability.
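One plausible way to write the two components, offered as an illustrative sketch rather than the library's exact formula: with soft assignments $a_{ik}$, embeddings $x_i$, $N$ cells, $K$ clusters, soft centroids $\mu_k$ (used when `centroids` is not supplied), $\lambda$ = `separation_weight` and $m$ = `min_separation`, a squared hinge penalizes centroid pairs closer than $m$.

$$
\mu_k = \frac{\sum_{i} a_{ik}\, x_i}{\sum_{i} a_{ik}}, \qquad
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{K} a_{ik} \lVert x_i - \mu_k \rVert^2
+ \frac{\lambda}{K(K-1)} \sum_{k \neq l} \max\bigl(0,\; m - \lVert \mu_k - \mu_l \rVert\bigr)^2
$$

Under this form, the first term is zero when every cell sits on the centroid of the clusters it is assigned to, and the second is zero once all centroid pairs are at least $m$ apart.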
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
separation_weight
|
float
|
Weight for the separation term. |
1.0
|
min_separation
|
float
|
Minimum desired distance between cluster centers. |
1.0
|
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
separation_weight
|
float
|
Weight for separation term. |
1.0
|
min_separation
|
float
|
Minimum desired distance between centroids. |
1.0
|
rngs
|
Rngs | None
|
Random number generators (not used, for API consistency). |
None
|
__call__
__call__(
embeddings: Float[Array, "n_cells latent_dim"],
assignments: Float[Array, "n_cells n_clusters"],
centroids: Float[Array, "n_clusters latent_dim"] | None = None,
) -> Float[Array, ""]
Compute clustering compactness loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `embeddings` | `Float[Array, 'n_cells latent_dim']` | Cell embeddings in latent space. | required |
| `assignments` | `Float[Array, 'n_cells n_clusters']` | Soft cluster assignments (each row should sum to 1). | required |
| `centroids` | `Float[Array, 'n_clusters latent_dim'] \| None` | Optional cluster centroids. If provided, they are used directly, so gradients flow into them. If `None`, soft centroids are computed from `assignments`. | `None` |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Combined compactness and separation loss (scalar). |
VelocityConsistencyLoss
`diffbio.losses.singlecell_losses.VelocityConsistencyLoss`
VelocityConsistencyLoss(
dt: float = 0.1,
cosine_weight: float = 1.0,
magnitude_weight: float = 1.0,
*,
rngs: Rngs | None = None,
)
Bases: Module
Loss function to enforce consistency between velocity and trajectory.
Encourages the predicted RNA velocity to agree with the actual change in expression over time by combining a directional (cosine) term and a magnitude term.
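A sketch of how the two terms might be combined, assuming (this is not confirmed by the source) that the observed displacement $\Delta_i = x_i^{\text{future}} - x_i$ is compared with the extrapolated displacement $\hat{\Delta}_i = \mathrm{dt}\, v_i$, with $w_{\cos}$ = `cosine_weight`, $w_{\text{mag}}$ = `magnitude_weight`, and $N$ cells:

$$
\mathcal{L} = \frac{w_{\cos}}{N} \sum_{i=1}^{N} \left(1 - \frac{\langle \hat{\Delta}_i, \Delta_i \rangle}{\lVert \hat{\Delta}_i \rVert\, \lVert \Delta_i \rVert}\right)
+ \frac{w_{\text{mag}}}{N} \sum_{i=1}^{N} \bigl(\lVert \hat{\Delta}_i \rVert - \lVert \Delta_i \rVert\bigr)^2
$$

Since $\mathrm{dt} > 0$, the cosine term does not depend on `dt`; in this form only the magnitude term uses it to put velocity and displacement on the same scale.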
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dt` | `float` | Time step for velocity extrapolation. | `0.1` |
| `cosine_weight` | `float` | Weight for the directional (cosine-similarity) term. | `1.0` |
| `magnitude_weight` | `float` | Weight for the magnitude term. | `1.0` |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. Not used; accepted for API consistency. | `None` |
__call__
__call__(
expression: Float[Array, "n_cells n_genes"],
velocity: Float[Array, "n_cells n_genes"],
future_expression: Float[Array, "n_cells n_genes"],
) -> Float[Array, ""]
Compute velocity consistency loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `expression` | `Float[Array, 'n_cells n_genes']` | Current gene expression. | required |
| `velocity` | `Float[Array, 'n_cells n_genes']` | Predicted RNA velocity (rate of change). | required |
| `future_expression` | `Float[Array, 'n_cells n_genes']` | Future gene expression (ground truth or estimated). | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Combined directional and magnitude consistency loss (scalar). |
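A minimal JAX sketch of such a loss follows. It is hypothetical: the function name `velocity_consistency_sketch`, the choice to compare `dt * velocity` with the observed displacement `future_expression - expression`, the `1 - cosine` directional term, the squared norm difference for the magnitude term, and the `eps` stabilizer are assumptions rather than diffbio's code.

```python
import jax
import jax.numpy as jnp


def velocity_consistency_sketch(
    expression: jax.Array,
    velocity: jax.Array,
    future_expression: jax.Array,
    dt: float = 0.1,
    cosine_weight: float = 1.0,
    magnitude_weight: float = 1.0,
    eps: float = 1e-8,
) -> jax.Array:
    """Hypothetical illustration of a direction + magnitude velocity loss."""
    # Observed displacement vs. displacement extrapolated from velocity over dt
    observed = future_expression - expression
    predicted = dt * velocity
    # Norms with a small eps inside the sqrt so gradients stay finite at zero
    obs_norm = jnp.sqrt(jnp.sum(observed**2, axis=-1) + eps)
    pred_norm = jnp.sqrt(jnp.sum(predicted**2, axis=-1) + eps)
    # Directional term: 1 - cosine similarity, averaged over cells
    cos_sim = jnp.sum(observed * predicted, axis=-1) / (obs_norm * pred_norm)
    cosine_loss = jnp.mean(1.0 - cos_sim)
    # Magnitude term: squared difference of displacement lengths, averaged over cells
    magnitude_loss = jnp.mean((pred_norm - obs_norm) ** 2)
    return cosine_weight * cosine_loss + magnitude_weight * magnitude_loss
```

If `future_expression` equals `expression + dt * velocity`, this sketch returns approximately 0; reversing the direction of change drives the cosine term toward 2.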
ShannonDiversityLoss
`diffbio.losses.singlecell_losses.ShannonDiversityLoss`
Bases: Module
Mean Shannon entropy of soft cluster assignments across cells.
Measures assignment diversity using Shannon entropy. Higher values indicate more uniform (diverse) cluster assignments, while lower values indicate concentrated assignments.
Delegates to `calibrax.metrics.functional.information.entropy` for the per-cell entropy computation.
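A minimal JAX sketch of the same quantity, written directly with `jax.scipy.special.xlogy` instead of the calibrax helper (whose exact signature is not reproduced here). The function name is hypothetical and natural logarithms are assumed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy


def shannon_diversity_sketch(assignments: jax.Array) -> jax.Array:
    """Hypothetical illustration: mean Shannon entropy of soft assignments."""
    # Per-cell entropy H_i = -sum_k p_ik * log(p_ik); xlogy treats 0 * log(0) as 0
    per_cell = -jnp.sum(xlogy(assignments, assignments), axis=-1)
    # Average over cells: 0 for one-hot rows, log(n_clusters) for uniform rows
    return jnp.mean(per_cell)
```

For uniform assignments over 4 clusters this returns ln 4 ≈ 1.386; for one-hot assignments it returns 0.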
__call__
Compute mean Shannon entropy of soft cluster assignments.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `assignments` | `Float[Array, 'n_cells n_clusters']` | Soft cluster probabilities of shape `(n_cells, n_clusters)`. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Mean Shannon entropy across cells (scalar). Range `[0, log(K)]`, where `K` is the number of clusters. |
SimpsonDiversityLoss
`diffbio.losses.singlecell_losses.SimpsonDiversityLoss`
Bases: Module
Mean Simpson concentration index of soft cluster assignments.
Computes the sum of squared assignment probabilities per cell, averaged across all cells. Lower values indicate more diverse (uniform) assignments.
- Uniform assignments over K clusters yield `1/K`.
- Fully concentrated (one-hot) assignments yield `1.0`.
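A minimal JAX sketch of the index as described above; the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def simpson_concentration_sketch(assignments: jax.Array) -> jax.Array:
    """Hypothetical illustration: mean Simpson concentration of soft assignments."""
    # Sum of squared probabilities per cell: 1/K for uniform rows, 1.0 for one-hot rows
    per_cell = jnp.sum(assignments**2, axis=-1)
    # Average over cells
    return jnp.mean(per_cell)
```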
__call__
Compute mean Simpson concentration index.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `assignments` | `Float[Array, 'n_cells n_clusters']` | Soft cluster probabilities of shape `(n_cells, n_clusters)`. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Mean sum of squared probabilities across cells (scalar). Range `[1/K, 1]`, where `K` is the number of clusters. |
Usage Examples
Batch Mixing Loss
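```python
import jax
from diffbio.losses import BatchMixingLoss

# Illustrative random inputs: 100 cells, 10-dim latent space, 3 batches
key_emb, key_lab = jax.random.split(jax.random.key(0))
latent_embeddings = jax.random.normal(key_emb, (100, 10))
batch_labels = jax.random.randint(key_lab, (100,), 0, 3)
# n_batches should equal the number of batches present in batch_labels
batch_loss = BatchMixingLoss(n_neighbors=15, n_batches=3, temperature=1.0)
# Lower loss means better mixing, so minimizing it maximizes batch mixing
loss = batch_loss(
    embeddings=latent_embeddings,  # (n_cells, latent_dim)
    batch_labels=batch_labels,  # (n_cells,)
)
```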
Clustering Compactness Loss
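```python
import jax
from diffbio.losses import ClusteringCompactnessLoss

# Illustrative random inputs: 100 cells, 10-dim latent space, 5 clusters
key_emb, key_logits = jax.random.split(jax.random.key(1))
cell_embeddings = jax.random.normal(key_emb, (100, 10))
# Soft assignments must sum to 1 per cell, so normalize logits with softmax
soft_assignments = jax.nn.softmax(jax.random.normal(key_logits, (100, 5)), axis=-1)
cluster_loss = ClusteringCompactnessLoss(separation_weight=1.0, min_separation=1.0)
# Encourage tight, well-separated clusters; centroids are derived from assignments
loss = cluster_loss(
    embeddings=cell_embeddings,
    assignments=soft_assignments,
)
```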
Combined Training
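```python
def combined_loss(model, data):
    # Forward pass; the first element of the returned tuple holds the model outputs
    result, _, _ = model.apply(data, {}, None)
    # Batch mixing: the loss is a negative entropy, so minimizing it maximizes mixing
    l_batch = batch_loss(result["embeddings"], data["batch_ids"])
    # Clustering compactness and separation
    l_cluster = cluster_loss(result["embeddings"], result["assignments"])
    return l_batch + 0.5 * l_cluster
```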