
Single-Cell Losses API

Loss functions for training single-cell analysis models.

BatchMixingLoss

diffbio.losses.singlecell_losses.BatchMixingLoss

BatchMixingLoss(
    n_neighbors: int = 15,
    n_batches: int = 3,
    temperature: float = 1.0,
    *,
    rngs: Rngs | None = None,
)

Bases: Module

Loss function to maximize batch mixing in latent space.

Computes how well batches are mixed in the embedding space by measuring the entropy of batch labels among k-nearest neighbors for each cell. Higher entropy indicates better mixing.

The loss encourages the model to learn representations where cells from different batches are interleaved, reducing batch effects.
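To make the description concrete, the kNN entropy computation can be sketched in plain JAX. This is a minimal sketch, not the diffbio implementation. The function name `batch_mixing_loss_sketch`, the squared Euclidean distances, the hard top-k neighbor selection with `jax.lax.top_k`, and the softmax weighting of those neighbors by `temperature` are all assumptions chosen to match the parameter descriptions.

```python
import jax
import jax.numpy as jnp


# Illustrative sketch only: the distance metric, top-k selection and softmax
# weighting are assumptions, not the diffbio implementation.
def batch_mixing_loss_sketch(
    embeddings: jax.Array,    # (n_cells, latent_dim)
    batch_labels: jax.Array,  # (n_cells,), integers in [0, n_batches)
    n_neighbors: int = 15,
    n_batches: int = 3,
    temperature: float = 1.0,
) -> jax.Array:
    n_cells = embeddings.shape[0]
    # Pairwise squared Euclidean distances; a large value on the diagonal
    # keeps each cell out of its own neighborhood.
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    sq_dist = jnp.sum(diff**2, axis=-1) + 1e10 * jnp.eye(n_cells)
    # k nearest neighbors: top_k of the negated distances.
    neg_dist, knn_idx = jax.lax.top_k(-sq_dist, n_neighbors)  # (n_cells, k)
    # Soft neighbor weights; temperature controls how peaked they are.
    weights = jax.nn.softmax(neg_dist / temperature, axis=-1)
    # n_batches must be a static Python int for one_hot under jit.
    onehot = jax.nn.one_hot(batch_labels[knn_idx], n_batches)  # (n_cells, k, n_batches)
    batch_probs = jnp.sum(weights[..., None] * onehot, axis=1)  # (n_cells, n_batches)
    entropy = -jnp.sum(batch_probs * jnp.log(batch_probs + 1e-10), axis=-1)
    # Negative mean entropy: lower loss means better mixing.
    return -jnp.mean(entropy)


# Two batches on a line: interleaved (well mixed) vs. far apart (separated).
positions = jnp.arange(8) * 0.1
mixed = jnp.stack([positions, jnp.zeros(8)], axis=1)
mixed_labels = jnp.array([0, 1, 0, 1, 0, 1, 0, 1])
separated = jnp.stack([positions + 100.0 * (jnp.arange(8) >= 4), jnp.zeros(8)], axis=1)
separated_labels = jnp.array([0, 0, 0, 0, 1, 1, 1, 1])
loss_mixed = batch_mixing_loss_sketch(mixed, mixed_labels, n_neighbors=3, n_batches=2)
loss_separated = batch_mixing_loss_sketch(separated, separated_labels, n_neighbors=3, n_batches=2)
```

Because `n_neighbors` and `n_batches` determine array shapes, a jitted version of this sketch would need them marked static, for example `jax.jit(batch_mixing_loss_sketch, static_argnames=("n_neighbors", "n_batches"))`. Gradients reach the embeddings through the soft neighbor weights, while the top-k selection itself is piecewise constant.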

Parameters:

  • n_neighbors (int, default 15): Number of nearest neighbors to consider for each cell.
  • n_batches (int, default 3): Number of batches. This is a static value, required for JIT compatibility.
  • temperature (float, default 1.0): Softmax temperature for soft neighbor selection in the distance computation.
  • rngs (Rngs | None, default None): Flax NNX random number generators. Not used; accepted for API consistency.

Example

loss_fn = BatchMixingLoss(n_neighbors=15, n_batches=3, rngs=nnx.Rngs(42))
loss = loss_fn(embeddings, batch_labels)

__call__

__call__(
    embeddings: Float[Array, "n_cells latent_dim"],
    batch_labels: Int[Array, "n_cells"],
) -> Float[Array, ""]

Compute batch mixing loss.

Parameters:

  • embeddings (Float[Array, 'n_cells latent_dim'], required): Cell embeddings in latent space.
  • batch_labels (Int[Array, 'n_cells'], required): Integer batch label for each cell.

Returns:

  • Float[Array, '']: Negative mean entropy of the batch-label distribution in each cell's neighborhood (scalar). Lower loss means better mixing.

ClusteringCompactnessLoss

diffbio.losses.singlecell_losses.ClusteringCompactnessLoss

ClusteringCompactnessLoss(
    separation_weight: float = 1.0,
    min_separation: float = 1.0,
    *,
    rngs: Rngs | None = None,
)

Bases: Module

Loss function to encourage compact and well-separated clusters.

Combines two components:

  1. Compactness: minimize within-cluster variance.
  2. Separation: maximize between-cluster distances.

Works with soft cluster assignments for end-to-end differentiability.
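The two components can be sketched in JAX as follows. This hypothetical `clustering_compactness_sketch` makes three assumptions: soft centroids are assignment-weighted means, compactness is the assignment-weighted squared distance to those centroids, and separation is a hinge penalty on centroid pairs closer than `min_separation`. The actual diffbio formulas may differ.

```python
import jax
import jax.numpy as jnp


# Illustrative sketch only: the centroid, compactness and hinge-separation
# formulas are assumptions, not the diffbio implementation.
def clustering_compactness_sketch(
    embeddings: jax.Array,   # (n_cells, latent_dim)
    assignments: jax.Array,  # (n_cells, n_clusters), rows sum to 1
    separation_weight: float = 1.0,
    min_separation: float = 1.0,
) -> jax.Array:
    # Soft centroids: assignment-weighted mean embedding of each cluster.
    mass = jnp.sum(assignments, axis=0)                                # (n_clusters,)
    centroids = (assignments.T @ embeddings) / (mass[:, None] + 1e-8)  # (n_clusters, latent_dim)
    # Compactness: assignment-weighted squared distance to each centroid, per cell.
    sq_dist = jnp.sum((embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    compactness = jnp.sum(assignments * sq_dist) / embeddings.shape[0]
    # Separation: hinge penalty on centroid pairs closer than min_separation.
    diff = centroids[:, None, :] - centroids[None, :, :]
    centroid_dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)  # eps keeps sqrt differentiable
    off_diag = 1.0 - jnp.eye(centroids.shape[0])
    hinge = jax.nn.relu(min_separation - centroid_dist) * off_diag
    separation = jnp.sum(hinge) / jnp.maximum(jnp.sum(off_diag), 1.0)
    return compactness + separation_weight * separation


# Hard assignments: cells 0-1 in cluster 0, cells 2-3 in cluster 1.
hard = jnp.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
tight = jnp.array([[0.0, 0.0], [0.1, 0.0], [5.0, 0.0], [5.1, 0.0]])
overlapping = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.2, 0.0], [0.8, 0.0]])
loss_tight = clustering_compactness_sketch(tight, hard)
loss_overlap = clustering_compactness_sketch(overlapping, hard)
```

The hinge form only penalizes centroid pairs that are too close, so well-separated clusters add no separation cost; in the overlapping case both centroids coincide and the separation term dominates.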

Parameters:

  • separation_weight (float, default 1.0): Weight for the separation term.
  • min_separation (float, default 1.0): Minimum desired distance between cluster centroids.
  • rngs (Rngs | None, default None): Flax NNX random number generators. Not used; accepted for API consistency.

Example

loss_fn = ClusteringCompactnessLoss(rngs=nnx.Rngs(42))
loss = loss_fn(embeddings, soft_assignments)

__call__

__call__(
    embeddings: Float[Array, "n_cells latent_dim"],
    assignments: Float[Array, "n_cells n_clusters"],
    centroids: Float[Array, "n_clusters latent_dim"]
    | None = None,
) -> Float[Array, ""]

Compute clustering compactness loss.

Parameters:

  • embeddings (Float[Array, 'n_cells latent_dim'], required): Cell embeddings in latent space.
  • assignments (Float[Array, 'n_cells n_clusters'], required): Soft cluster assignments; each row should sum to 1.
  • centroids (Float[Array, 'n_clusters latent_dim'] | None, default None): Optional cluster centroids. If provided, they are used directly, so gradients flow into them. If None, soft centroids are computed from the assignments.

Returns:

  • Float[Array, '']: Combined compactness and separation loss (scalar).

VelocityConsistencyLoss

diffbio.losses.singlecell_losses.VelocityConsistencyLoss

VelocityConsistencyLoss(
    dt: float = 0.1,
    cosine_weight: float = 1.0,
    magnitude_weight: float = 1.0,
    *,
    rngs: Rngs | None = None,
)

Bases: Module

Loss function to enforce consistency between velocity and trajectory.

Ensures that the predicted RNA velocity is consistent with actual expression changes over time. Combines directional (cosine) and magnitude consistency.
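One plausible way to combine the two terms is sketched below. It assumes the velocity is scaled by `dt` and compared with the observed displacement `future_expression - expression`. The function `velocity_consistency_sketch`, the `1 - cosine` direction term, and the squared norm difference used for magnitude are assumptions for illustration, not the documented diffbio formulas.

```python
import jax
import jax.numpy as jnp


# Illustrative sketch only: the direction and magnitude terms below are
# assumptions, not the diffbio implementation.
def velocity_consistency_sketch(
    expression: jax.Array,         # (n_cells, n_genes)
    velocity: jax.Array,           # (n_cells, n_genes)
    future_expression: jax.Array,  # (n_cells, n_genes)
    dt: float = 0.1,
    cosine_weight: float = 1.0,
    magnitude_weight: float = 1.0,
    eps: float = 1e-8,
) -> jax.Array:
    # Observed change vs. the change implied by extrapolating velocity over dt.
    observed = future_expression - expression
    predicted = dt * velocity
    obs_norm = jnp.linalg.norm(observed, axis=-1)
    pred_norm = jnp.linalg.norm(predicted, axis=-1)
    # Directional term: 1 - cosine similarity (0 when aligned, 2 when opposite).
    cosine = jnp.sum(observed * predicted, axis=-1) / (obs_norm * pred_norm + eps)
    direction_loss = jnp.mean(1.0 - cosine)
    # Magnitude term: squared difference between displacement lengths.
    magnitude_loss = jnp.mean((pred_norm - obs_norm) ** 2)
    return cosine_weight * direction_loss + magnitude_weight * magnitude_loss


expression = jnp.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])
velocity = jnp.array([[1.0, 0.0, -1.0], [2.0, 2.0, 0.0]])
future = expression + 0.1 * velocity  # exactly consistent with dt = 0.1
consistent = velocity_consistency_sketch(expression, velocity, future)
reversed_velocity = velocity_consistency_sketch(expression, -velocity, future)
```

A velocity that points the right way with the right length gives a loss near zero, while reversing it leaves the magnitude term unchanged but drives the direction term to its maximum of 2.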

Parameters:

  • dt (float, default 0.1): Time step for velocity extrapolation.
  • cosine_weight (float, default 1.0): Weight for the directional (cosine similarity) consistency term.
  • magnitude_weight (float, default 1.0): Weight for the magnitude consistency term.
  • rngs (Rngs | None, default None): Flax NNX random number generators. Not used; accepted for API consistency.

Example

loss_fn = VelocityConsistencyLoss(rngs=nnx.Rngs(42))
loss = loss_fn(expression, velocity, future_expression)

__call__

__call__(
    expression: Float[Array, "n_cells n_genes"],
    velocity: Float[Array, "n_cells n_genes"],
    future_expression: Float[Array, "n_cells n_genes"],
) -> Float[Array, ""]

Compute velocity consistency loss.

Parameters:

  • expression (Float[Array, 'n_cells n_genes'], required): Current gene expression.
  • velocity (Float[Array, 'n_cells n_genes'], required): Predicted RNA velocity (rate of change of expression).
  • future_expression (Float[Array, 'n_cells n_genes'], required): Future gene expression (ground truth or estimated).

Returns:

  • Float[Array, '']: Combined directional and magnitude consistency loss (scalar).

ShannonDiversityLoss

diffbio.losses.singlecell_losses.ShannonDiversityLoss

ShannonDiversityLoss()

Bases: Module

Mean Shannon entropy of soft cluster assignments across cells.

Measures assignment diversity using Shannon entropy. Higher values indicate more uniform (diverse) cluster assignments, while lower values indicate concentrated assignments.

Delegates to calibrax.metrics.functional.information.entropy for the per-cell entropy computation.
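For comparison, the same quantity can be reproduced without calibrax. This hypothetical `shannon_diversity_sketch` uses `jax.scipy.special.entr` and the natural logarithm. The choice of base is an assumption about calibrax. The usage lines check both ends of the [0, log(K)] range.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import entr


# Illustrative sketch only: natural-log entropy, not the calibrax code path.
def shannon_diversity_sketch(assignments: jax.Array) -> jax.Array:
    # entr(p) = -p * log(p) elementwise, with entr(0) = 0 (no NaN for one-hot rows).
    per_cell_entropy = jnp.sum(entr(assignments), axis=-1)
    return jnp.mean(per_cell_entropy)


n_clusters = 4
uniform = jnp.full((3, n_clusters), 1.0 / n_clusters)  # upper bound: log(K)
one_hot = jnp.eye(n_clusters)                            # lower bound: 0
upper = shannon_diversity_sketch(uniform)
lower = shannon_diversity_sketch(one_hot)
```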

Example
loss_fn = ShannonDiversityLoss()
# Soft cluster probabilities: (n_cells, n_clusters)
assignments = jax.nn.softmax(logits, axis=-1)
diversity = loss_fn(assignments)  # scalar, higher = more diverse

__call__

__call__(
    assignments: Float[Array, "n_cells n_clusters"],
) -> Float[Array, ""]

Compute mean Shannon entropy of soft cluster assignments.

Parameters:

  • assignments (Float[Array, 'n_cells n_clusters'], required): Soft cluster probabilities of shape (n_cells, n_clusters). Each row should sum to 1.

Returns:

  • Float[Array, '']: Mean Shannon entropy across cells (scalar), in the range [0, log(K)], where K is the number of clusters.

SimpsonDiversityLoss

diffbio.losses.singlecell_losses.SimpsonDiversityLoss

SimpsonDiversityLoss()

Bases: Module

Mean Simpson concentration index of soft cluster assignments.

Computes the sum of squared assignment probabilities per cell, averaged across all cells. Lower values indicate more diverse (uniform) assignments.

  • Uniform assignments over K clusters yield 1/K.
  • Fully concentrated (one-hot) assignments yield 1.0.
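The index is simple enough to write directly. The hypothetical `simpson_concentration_sketch` below reproduces the two bounds listed above. It also assumes assignments come from a softmax over logits, as in the example that follows. Under that assumption, uniform logits are a stationary point: the gradient of the index with respect to the logits vanishes there.

```python
import jax
import jax.numpy as jnp


# Illustrative sketch only, not the diffbio implementation.
def simpson_concentration_sketch(assignments: jax.Array) -> jax.Array:
    # Mean over cells of sum_k p_k^2 (lower = more diverse).
    return jnp.mean(jnp.sum(assignments**2, axis=-1))


n_clusters = 4
uniform = jnp.full((5, n_clusters), 1.0 / n_clusters)  # gives 1/K
one_hot = jnp.eye(n_clusters)                            # gives 1.0


def index_from_logits(logits: jax.Array) -> jax.Array:
    # Hypothetical helper: soft assignments from logits, then the index.
    return simpson_concentration_sketch(jax.nn.softmax(logits, axis=-1))


grad_at_uniform = jax.grad(index_from_logits)(jnp.zeros((5, n_clusters)))
```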
Example
loss_fn = SimpsonDiversityLoss()
assignments = jax.nn.softmax(logits, axis=-1)
concentration = loss_fn(assignments)  # scalar, lower = more diverse

__call__

__call__(
    assignments: Float[Array, "n_cells n_clusters"],
) -> Float[Array, ""]

Compute mean Simpson concentration index.

Parameters:

  • assignments (Float[Array, 'n_cells n_clusters'], required): Soft cluster probabilities of shape (n_cells, n_clusters). Each row should sum to 1.

Returns:

  • Float[Array, '']: Mean sum of squared probabilities across cells (scalar), in the range [1/K, 1.0], where K is the number of clusters.

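Usage Examples

Batch Mixing Loss

from diffbio.losses import BatchMixingLoss

# n_batches is the number of batches in batch_labels
# (a static value, required for JIT compatibility)
batch_loss = BatchMixingLoss(n_neighbors=15, n_batches=3, temperature=1.0)

# Minimizing this loss (negative neighborhood entropy) maximizes batch mixing
loss = batch_loss(
    embeddings=latent_embeddings,  # (n_cells, latent_dim)
    batch_labels=batch_labels,     # (n_cells,)
)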

Clustering Compactness Loss

from diffbio.losses import ClusteringCompactnessLoss

cluster_loss = ClusteringCompactnessLoss(separation_weight=1.0, min_separation=1.0)

# Encourage compact, well-separated clusters
loss = cluster_loss(
    embeddings=cell_embeddings,
    assignments=soft_assignments,
)

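Combined Training

def combined_loss(model, data):
    # Forward pass; the model is assumed to return a dict with
    # "embeddings" (n_cells, latent_dim) and "assignments" (n_cells, n_clusters)
    result, _, _ = model.apply(data, {}, None)

    # Batch mixing: BatchMixingLoss already returns negative entropy,
    # so adding it as-is (not negated) maximizes mixing when minimized
    l_batch = batch_loss(result["embeddings"], data["batch_ids"])

    # Clustering compactness and separation
    l_cluster = cluster_loss(result["embeddings"], result["assignments"])

    return l_batch + 0.5 * l_cluster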