Single-Cell Pipeline API

End-to-end differentiable single-cell RNA-seq analysis pipeline with scVI-style VAE, Harmony batch correction, and soft clustering.

SingleCellPipeline

diffbio.pipelines.single_cell.SingleCellPipeline

SingleCellPipeline(
    config: SingleCellPipelineConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)

Bases: OperatorModule

End-to-end differentiable single-cell analysis pipeline.

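This pipeline processes single-cell RNA-seq data through the following analysis steps; every step except VAE normalization can be switched off in the configuration:

  1. Ambient RNA removal
  2. scVI-style VAE normalization
  3. Harmony batch correction
  4. UMAP dimensionality reduction
  5. Soft k-means clustering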

Input data structure
  • counts: Float[Array, "n_cells n_genes"] - Raw count matrix
  • ambient_profile: Float[Array, "n_genes"] - Ambient expression profile
  • batch_labels: Int[Array, "n_cells"] - Batch assignments
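A minimal sketch of building and checking an input dict with the keys and shapes above. make_toy_inputs and validate_inputs are hypothetical helpers, not part of diffbio, and the sketch assumes the ambient profile is a probability vector that sums to 1 (the Input Specifications only say it is normalized).

```python
import jax
import jax.numpy as jnp


def make_toy_inputs(n_cells: int, n_genes: int, n_batches: int, seed: int = 0) -> dict:
    """Build a toy input dict with the documented keys and shapes (hypothetical helper)."""
    k_counts, k_ambient, k_batch = jax.random.split(jax.random.PRNGKey(seed), 3)
    counts = jax.random.poisson(k_counts, lam=5.0, shape=(n_cells, n_genes)).astype(jnp.float32)
    # Assumption: the ambient profile is a probability vector over genes.
    ambient = jax.random.uniform(k_ambient, (n_genes,))
    ambient = ambient / ambient.sum()
    batch_labels = jax.random.randint(k_batch, (n_cells,), 0, n_batches)
    return {"counts": counts, "ambient_profile": ambient, "batch_labels": batch_labels}


def validate_inputs(data: dict, n_genes: int) -> None:
    """Raise ValueError if the input dict does not match the documented shapes."""
    counts = data["counts"]
    if counts.ndim != 2 or counts.shape[1] != n_genes:
        raise ValueError(f"counts must be (n_cells, {n_genes}), got {counts.shape}")
    n_cells = counts.shape[0]
    if data["ambient_profile"].shape != (n_genes,):
        raise ValueError("ambient_profile must have shape (n_genes,)")
    if data["batch_labels"].shape != (n_cells,):
        raise ValueError("batch_labels must have shape (n_cells,)")
    if not jnp.issubdtype(data["batch_labels"].dtype, jnp.integer):
        raise ValueError("batch_labels must be an integer array")
```

The n_genes passed to validate_inputs should match the n_genes the pipeline was configured with.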

Output data structure (adds):
  • decontaminated_counts: Ambient-removed counts (if enabled)
  • normalized: VAE-normalized expression
  • latent: Latent space representation
  • corrected_embeddings: Batch-corrected embeddings (if enabled)
  • embeddings_2d: UMAP embeddings with umap_n_components dimensions, 2 by default (if enabled)
  • cluster_assignments: Soft cluster assignments
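The corrected_embeddings output comes from Harmony batch correction. As rough intuition for what batch_correction_clusters and batch_correction_iterations control, here is a heavily simplified, Harmony-inspired sketch: cells are soft-assigned to clusters, and within each cluster every batch's mean offset from the cluster mean is subtracted. It omits Harmony's diversity penalty and ridge regularization and is not diffbio's implementation; harmony_like_correction and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def harmony_like_correction(z, batch_labels, n_batches, centroids,
                            n_iters=10, sigma=1.0, eps=1e-8):
    """Simplified Harmony-inspired correction (hypothetical; no diversity penalty or ridge)."""
    one_hot = jax.nn.one_hot(batch_labels, n_batches)                  # (n_cells, B)
    corrected = z
    for _ in range(n_iters):
        # Soft-assign the current corrected cells to clusters.
        d2 = jnp.sum((corrected[:, None, :] - centroids[None]) ** 2, axis=-1)
        r = jax.nn.softmax(-d2 / sigma, axis=-1)                        # (n_cells, K)
        centroids = (r.T @ corrected) / (r.sum(axis=0)[:, None] + eps)  # (K, D)
        # Responsibility-weighted means of the ORIGINAL embeddings per (cluster, batch).
        w = r[:, :, None] * one_hot[:, None, :]                         # (n_cells, K, B)
        batch_means = jnp.einsum("nkb,nd->kbd", w, z) / (w.sum(axis=0)[..., None] + eps)
        cluster_means = (r.T @ z) / (r.sum(axis=0)[:, None] + eps)      # (K, D)
        offsets = batch_means - cluster_means[:, None, :]               # (K, B, D)
        # Remove each cell's batch offset, mixed over its cluster memberships.
        corrected = z - jnp.einsum("nk,nb,kbd->nd", r, one_hot, offsets)
    return corrected
```

With a single cluster (centroids of shape (1, D)) the correction reduces to shifting each batch's mean onto the global mean; more clusters let the offset differ between cell populations.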

The pipeline is fully differentiable, so all components can be optimized jointly with gradient-based training for tasks such as:
  • Supervised cell type classification
  • Semi-supervised clustering
  • Multi-task learning across batches
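To illustrate what joint optimization means, the sketch below chains two differentiable stages, a linear stand-in for the encoder and a temperature-scaled soft cluster assignment, and takes the gradient of a supervised cross-entropy loss. The gradient reaches the upstream encoder weights because the assignment is a softmax rather than a hard argmax. toy_pipeline and supervised_loss are hypothetical and far simpler than SingleCellPipeline.

```python
import jax
import jax.numpy as jnp


def toy_pipeline(params, x, temperature=1.0):
    """Two differentiable stages: a linear 'encoder' then soft cluster assignment."""
    z = x @ params["w_enc"]                                               # (n_cells, latent)
    d2 = jnp.sum((z[:, None, :] - params["centroids"][None]) ** 2, axis=-1)
    return jax.nn.log_softmax(-d2 / temperature, axis=-1)                 # (n_cells, K) log-probs


def supervised_loss(params, x, labels):
    """Cross-entropy between soft cluster log-probabilities and known cell-type labels."""
    log_p = toy_pipeline(params, x)
    return -jnp.mean(jnp.take_along_axis(log_p, labels[:, None], axis=1))


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (16, 20))
params = {
    "w_enc": 0.1 * jax.random.normal(k2, (20, 4)),
    "centroids": jax.random.normal(k3, (3, 4)),
}
labels = jnp.arange(16) % 3

# One backward pass yields gradients for both stages at once.
loss, grads = jax.value_and_grad(supervised_loss)(params, x, labels)
```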

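Example
from flax import nnx
from diffbio.pipelines import SingleCellPipeline, SingleCellPipelineConfig

config = SingleCellPipelineConfig(n_genes=2000, n_clusters=10)
pipeline = SingleCellPipeline(config, rngs=nnx.Rngs(42))
# data: dict with "counts", "ambient_profile" and "batch_labels" (see Input data structure)
result, state, meta = pipeline.apply(data, {}, None)
clusters = result["cluster_assignments"]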

Parameters:

  • config: SingleCellPipelineConfig (required) - Pipeline configuration.
  • rngs: Rngs (required) - Random number generators for parameter initialization.
  • name: str | None (default: None) - Optional name for the pipeline.

apply

apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Array], dict[str, Any], dict[str, Any] | None
]

Apply the full single-cell analysis pipeline.

Parameters:

  • data: dict[str, Array] (required) - Input data containing:
      • counts: Float[Array, "n_cells n_genes"]
      • ambient_profile: Float[Array, "n_genes"]
      • batch_labels: Int[Array, "n_cells"]
  • state: dict[str, Any] (required) - Element state (passed through).
  • metadata: dict[str, Any] | None (required) - Element metadata (passed through).
  • random_params: Any (default: None) - Random parameters for stochastic operations.
  • stats: dict[str, Any] | None (default: None) - Optional statistics dict.

Returns:

  • tuple[dict[str, Array], dict[str, Any], dict[str, Any] | None] - Tuple of (output_data, state, metadata), where output_data contains all input keys plus the analysis outputs.
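A hypothetical check_outputs helper (not part of diffbio) that encodes this return contract: input keys pass through unchanged, latent has shape (n_cells, latent_dim), and cluster_assignments, when present, holds per-cell probabilities that sum to 1. The shapes follow the Output Specifications table.

```python
import jax.numpy as jnp


def check_outputs(data: dict, result: dict, latent_dim: int, n_clusters: int) -> None:
    """Raise ValueError if result violates the documented apply() output contract."""
    # Every input key must be passed through unchanged.
    for key, value in data.items():
        if key not in result:
            raise ValueError(f"input key {key!r} was not passed through")
        if not jnp.array_equal(result[key], value):
            raise ValueError(f"input key {key!r} was modified")
    n_cells = data["counts"].shape[0]
    if result["latent"].shape != (n_cells, latent_dim):
        raise ValueError(f"latent has shape {result['latent'].shape}")
    probs = result.get("cluster_assignments")
    if probs is not None:
        if probs.shape != (n_cells, n_clusters):
            raise ValueError(f"cluster_assignments has shape {probs.shape}")
        if not jnp.allclose(probs.sum(axis=-1), 1.0, atol=1e-5):
            raise ValueError("cluster_assignments rows do not sum to 1")
```

With the Quick Start settings, this would be called as check_outputs(data, result, latent_dim=64, n_clusters=10) after result, _, _ = pipeline.apply(data, {}, None).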

SingleCellPipelineConfig

diffbio.pipelines.single_cell.SingleCellPipelineConfig (dataclass)

SingleCellPipelineConfig(
    n_genes: int = 2000,
    n_clusters: int = 10,
    latent_dim: int = 64,
    hidden_dims: tuple[int, ...] = (128, 64),
    umap_n_components: int = 2,
    batch_correction_clusters: int = 100,
    batch_correction_iterations: int = 10,
    clustering_temperature: float = 1.0,
    enable_ambient_removal: bool = True,
    enable_batch_correction: bool = True,
    enable_dim_reduction: bool = True,
    enable_clustering: bool = True,
)

Bases: OperatorConfig

Configuration for the single-cell analysis pipeline.

Attributes:

  • n_genes: int - Number of genes in the expression matrix.
  • n_clusters: int - Number of clusters for soft k-means.
  • latent_dim: int - Dimension of the VAE latent space.
  • hidden_dims: tuple[int, ...] - Hidden layer dimensions for the VAE.
  • umap_n_components: int - Number of UMAP output dimensions.
  • batch_correction_clusters: int - Number of clusters used by Harmony.
  • batch_correction_iterations: int - Number of Harmony iterations.
  • clustering_temperature: float - Temperature for soft clustering.
  • enable_ambient_removal: bool - Whether to enable ambient RNA removal.
  • enable_batch_correction: bool - Whether to enable batch correction.
  • enable_dim_reduction: bool - Whether to enable UMAP dimensionality reduction.
  • enable_clustering: bool - Whether to enable soft clustering.
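The exact soft-clustering rule is not documented here. A minimal soft k-means sketch, assuming clustering_temperature divides squared distances inside a softmax (a common formulation), shows the role of the temperature: lower values push each cell's probabilities toward a single cluster, higher values spread them out. soft_kmeans and soft_kmeans_assign are hypothetical names.

```python
import jax
import jax.numpy as jnp


def soft_kmeans_assign(x, centroids, temperature):
    """Soft assignment: softmax over negative squared distances scaled by 1/temperature."""
    d2 = jnp.sum((x[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jax.nn.softmax(-d2 / temperature, axis=-1)


def soft_kmeans(x, n_clusters, temperature=1.0, n_iters=20, seed=0):
    """Alternate soft assignments and responsibility-weighted centroid updates."""
    # Initialize centroids from randomly chosen cells.
    idx = jax.random.choice(jax.random.PRNGKey(seed), x.shape[0], (n_clusters,), replace=False)
    centroids = x[idx]
    for _ in range(n_iters):
        r = soft_kmeans_assign(x, centroids, temperature)               # (n_cells, K)
        centroids = (r.T @ x) / (r.sum(axis=0)[:, None] + 1e-8)         # weighted means
    return centroids, soft_kmeans_assign(x, centroids, temperature)
```

For fixed centroids, each cell's largest probability can only shrink as the temperature grows, which is why a smaller clustering_temperature gives harder, more argmax-like assignments while keeping gradients defined.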

Factory Function

create_single_cell_pipeline

diffbio.pipelines.single_cell.create_single_cell_pipeline

create_single_cell_pipeline(
    n_genes: int = 2000,
    n_clusters: int = 10,
    latent_dim: int = 64,
    umap_n_components: int = 2,
    enable_ambient_removal: bool = True,
    enable_batch_correction: bool = True,
    enable_dim_reduction: bool = True,
    enable_clustering: bool = True,
    seed: int = 42,
) -> SingleCellPipeline

Factory function to create a single-cell analysis pipeline.

Parameters:

  • n_genes: int (default: 2000) - Number of genes in the expression matrix.
  • n_clusters: int (default: 10) - Number of clusters for soft k-means.
  • latent_dim: int (default: 64) - Dimension of the VAE latent space.
  • umap_n_components: int (default: 2) - Number of UMAP output dimensions.
  • enable_ambient_removal: bool (default: True) - Whether to enable ambient RNA removal.
  • enable_batch_correction: bool (default: True) - Whether to enable batch correction.
  • enable_dim_reduction: bool (default: True) - Whether to enable UMAP dimensionality reduction.
  • enable_clustering: bool (default: True) - Whether to enable soft clustering.
  • seed: int (default: 42) - Random seed.

Returns:

  • SingleCellPipeline - Configured SingleCellPipeline instance.

Usage Examples

Quick Start

from diffbio.pipelines import create_single_cell_pipeline
import jax
import jax.numpy as jnp

# Create pipeline
pipeline = create_single_cell_pipeline(
    n_genes=2000,
    n_clusters=10,
    latent_dim=64,
)

# Prepare data
n_cells = 100
n_genes = 2000

data = {
    "counts": jax.random.poisson(
        jax.random.PRNGKey(0), lam=5.0, shape=(n_cells, n_genes)
    ).astype(jnp.float32),
    "ambient_profile": jnp.ones(n_genes) / n_genes,
    "batch_labels": jax.random.randint(jax.random.PRNGKey(1), (n_cells,), 0, 3),
}

# Run pipeline
result, _, _ = pipeline.apply(data, {}, None)
clusters = jnp.argmax(result["cluster_assignments"], axis=-1)
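Since cluster_assignments holds soft probabilities, it carries more than the argmax labels. A minimal sketch of summarizing it; summarize_assignments and the 0.5 confidence threshold are illustrative choices, not part of diffbio.

```python
import jax.numpy as jnp


def summarize_assignments(probs, confidence_threshold=0.5):
    """Summarize an (n_cells, n_clusters) matrix of soft cluster probabilities."""
    labels = jnp.argmax(probs, axis=-1)                     # hard label per cell
    confidence = jnp.max(probs, axis=-1)                    # probability of that label
    sizes = jnp.bincount(labels, length=probs.shape[-1])    # cells per cluster (hard)
    soft_sizes = probs.sum(axis=0)                          # expected cells per cluster
    uncertain = confidence < confidence_threshold           # ambiguous cells
    return {
        "labels": labels,
        "confidence": confidence,
        "sizes": sizes,
        "soft_sizes": soft_sizes,
        "uncertain": uncertain,
    }
```

Applied to the Quick Start output, summarize_assignments(result["cluster_assignments"]) would return hard labels equal to clusters above, plus per-cluster sizes and a mask of cells whose top probability falls below the threshold.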

Full Configuration

from diffbio.pipelines import SingleCellPipeline, SingleCellPipelineConfig
from flax import nnx

config = SingleCellPipelineConfig(
    n_genes=5000,
    n_clusters=15,
    latent_dim=128,
    hidden_dims=(256, 128),
    umap_n_components=2,
    batch_correction_clusters=50,
    batch_correction_iterations=20,
    clustering_temperature=0.5,
    enable_ambient_removal=True,
    enable_batch_correction=True,
    enable_dim_reduction=True,
    enable_clustering=True,
)

pipeline = SingleCellPipeline(config, rngs=nnx.Rngs(42))
# Note: this pipeline has no training-mode toggle; submodules manage their
# own dropout/training state when applicable.

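Training Mode

# SingleCellPipeline does not expose train_mode/eval_mode toggles.
# Submodules that use dropout manage their own state during apply().
# `dataloader` and `train_step` are user-supplied placeholders: an iterable of
# input dicts and a function that computes a loss and updates the parameters.
for batch in dataloader:
    loss = train_step(pipeline, batch)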
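The loop above shows only its shape. The following self-contained sketch fills in both placeholders with plain JAX: a hypothetical dataloader yields mini-batches, and a jitted train_step computes a loss and its gradient and applies an SGD update. For brevity the trained "model" is a set of centroids under a soft k-means loss rather than a SingleCellPipeline; training the real pipeline would instead differentiate through its NNX parameters (typically with an optax optimizer), which is not shown here.

```python
import jax
import jax.numpy as jnp


def soft_kmeans_loss(centroids, x, temperature=1.0):
    """Expected squared distance of each cell to the centroids under soft assignments."""
    d2 = jnp.sum((x[:, None, :] - centroids[None]) ** 2, axis=-1)
    r = jax.nn.softmax(-d2 / temperature, axis=-1)
    return jnp.mean(jnp.sum(r * d2, axis=-1))


@jax.jit
def train_step(centroids, x):
    """One SGD step (learning rate 0.1) on the soft k-means loss."""
    loss, grad = jax.value_and_grad(soft_kmeans_loss)(centroids, x)
    return centroids - 0.1 * grad, loss


def dataloader(key, n_batches, batch_size=32):
    """Hypothetical loader: mini-batches drawn from two blobs centred at (-4, -4) and (4, 4)."""
    for _ in range(n_batches):
        key, k_noise, k_side = jax.random.split(key, 3)
        side = jax.random.bernoulli(k_side, 0.5, (batch_size, 1))
        yield jax.random.normal(k_noise, (batch_size, 2)) + jnp.where(side, 4.0, -4.0)


centroids = jnp.array([[-1.0, 0.0], [1.0, 0.0]])
losses = []
for batch in dataloader(jax.random.PRNGKey(0), n_batches=50):
    centroids, loss = train_step(centroids, batch)
    losses.append(float(loss))
```

Returning the updated parameters from a pure, jitted step is the functional JAX style; NNX instead mutates module state in place, but the loss-gradient-update cycle is the same.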

Access Components

# Ambient removal (if enabled)
if pipeline.ambient_removal is not None:
    ambient_module = pipeline.ambient_removal

# VAE normalizer (always available)
vae_module = pipeline.vae_normalizer

# Batch correction (if enabled)
if pipeline.batch_correction is not None:
    harmony_module = pipeline.batch_correction

# Dimensionality reduction (if enabled)
if pipeline.dim_reduction is not None:
    umap_module = pipeline.dim_reduction

# Clustering (if enabled)
if pipeline.clustering is not None:
    clustering_module = pipeline.clustering
    centroids = pipeline.clustering.centroids  # Cluster centers

Minimal Pipeline

# Create pipeline with only VAE normalization
minimal_pipeline = create_single_cell_pipeline(
    n_genes=2000,
    enable_ambient_removal=False,
    enable_batch_correction=False,
    enable_dim_reduction=False,
    enable_clustering=False,
)
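With every optional step disabled, the pipeline's main output is the VAE-normalized expression. In scVI, the decoder maps a latent vector to per-cell gene proportions through a softmax, and the mean of the count likelihood is those proportions times the cell's library size. The sketch below shows that decoder-side computation with a single linear layer. Whether diffbio's normalized output is the raw proportions or a rescaled version is not stated here, so library_scale=1e4 is an assumption, and both helper functions are hypothetical.

```python
import jax
import jax.numpy as jnp


def scvi_style_normalize(z, w_dec, b_dec, library_scale=1e4):
    """Map latents to per-cell gene proportions, then rescale to a common library size."""
    proportions = jax.nn.softmax(z @ w_dec + b_dec, axis=-1)   # each row sums to 1
    return library_scale * proportions


def expected_counts(z, w_dec, b_dec, counts):
    """Likelihood mean in scVI style: gene proportions times each cell's observed library size."""
    library = counts.sum(axis=-1, keepdims=True)               # (n_cells, 1)
    return jax.nn.softmax(z @ w_dec + b_dec, axis=-1) * library
```

Because the proportions are library-size independent, cells sequenced at different depths become directly comparable, which is the point of the normalization step.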

Input Specifications

Key Shape Description
counts (n_cells, n_genes) Raw count matrix
ambient_profile (n_genes,) Ambient expression profile (normalized)
batch_labels (n_cells,) Integer batch assignments
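ambient_profile describes the expression of ambient (cell-free) RNA. One simple way to use it, in the spirit of SoupX-style correction, is to subtract each cell's expected ambient counts (contamination fraction times library size times profile) and clip at zero. This is a sketch of that idea, not diffbio's ambient-removal method; subtract_ambient and contamination_fraction are hypothetical.

```python
import jax.numpy as jnp


def subtract_ambient(counts, ambient_profile, contamination_fraction):
    """Subtract each cell's expected ambient counts and clip at zero."""
    profile = ambient_profile / ambient_profile.sum()                   # ensure it sums to 1
    library = counts.sum(axis=-1, keepdims=True)                        # (n_cells, 1)
    ambient_counts = contamination_fraction * library * profile[None]   # (n_cells, n_genes)
    return jnp.maximum(counts - ambient_counts, 0.0)
```

contamination_fraction may be a scalar or an (n_cells, 1) array of per-cell fractions. The clip has zero gradient for clipped entries, so a fully differentiable variant might use a smooth function such as jax.nn.softplus instead.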

Output Specifications

Key Shape Description
counts (n_cells, n_genes) Original count matrix
ambient_profile (n_genes,) Original ambient profile
batch_labels (n_cells,) Original batch labels
decontaminated_counts (n_cells, n_genes) Ambient-removed counts*
normalized (n_cells, n_genes) VAE-normalized expression
latent (n_cells, latent_dim) Latent space representation
corrected_embeddings (n_cells, latent_dim) Batch-corrected embeddings
embeddings_2d (n_cells, umap_n_components) UMAP embeddings (2D by default)
cluster_assignments (n_cells, n_clusters) Soft cluster probabilities

*Only present when enable_ambient_removal=True