Single-Cell Pipeline API
End-to-end differentiable single-cell RNA-seq analysis pipeline with scVI-style VAE, Harmony batch correction, and soft clustering.
SingleCellPipeline
diffbio.pipelines.single_cell.SingleCellPipeline
SingleCellPipeline(
    config: SingleCellPipelineConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)
Bases: OperatorModule
End-to-end differentiable single-cell analysis pipeline.
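This pipeline processes single-cell RNA-seq data through several analysis steps: ambient RNA removal (optional), VAE normalization, Harmony batch correction (optional), UMAP dimensionality reduction (optional), and soft clustering.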
Input data structure
- counts: Float[Array, "n_cells n_genes"] - Raw count matrix
- ambient_profile: Float[Array, "n_genes"] - Ambient expression profile
- batch_labels: Int[Array, "n_cells"] - Batch assignments
Output data structure (adds):
- decontaminated_counts: Ambient-removed counts (if enabled)
- normalized: VAE-normalized expression
- latent: Latent space representation
- corrected_embeddings: Batch-corrected embeddings (if enabled)
- embeddings_2d: 2D UMAP embeddings (if enabled)
- cluster_assignments: Soft cluster assignments

The pipeline is fully differentiable, so all components can be optimized jointly with gradient-based training for tasks such as:
- Supervised cell type classification
- Semi-supervised clustering
- Multi-task learning across batches
Example: see Quick Start under Usage Examples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `SingleCellPipelineConfig` | Pipeline configuration. | required |
| `rngs` | `Rngs` | Random number generators for parameter initialization. | required |
| `name` | `str \| None` | Optional name for the pipeline. | `None` |
apply
apply(
    data: dict[str, Array],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Array], dict[str, Any], dict[str, Any] | None]
Apply the full single-cell analysis pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Array]` | Input data containing `counts` (Float[Array, "n_cells n_genes"]), `ambient_profile` (Float[Array, "n_genes"]), and `batch_labels` (Int[Array, "n_cells"]). | required |
| `state` | `dict[str, Any]` | Element state (passed through). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through). | required |
| `random_params` | `Any` | Random parameters for stochastic operations. | `None` |
| `stats` | `dict[str, Any] \| None` | Optional statistics dict. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[dict[str, Array], dict[str, Any], dict[str, Any] \| None]` | Tuple of `(output_data, state, metadata)`, where `output_data` contains all input keys plus the analysis outputs. |
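Because `apply` expects three input keys with mutually consistent shapes, a small pre-flight check can catch mismatches before the pipeline runs. The sketch below is not part of diffbio: `check_input_shapes` is a hypothetical helper that works on plain shape tuples (for example `array.shape`), and it assumes all three keys are supplied even when a stage is disabled.

```python
# Hypothetical helper, not part of diffbio: validates the shapes of the
# input dict documented for SingleCellPipeline.apply.
REQUIRED_KEYS = ("counts", "ambient_profile", "batch_labels")


def check_input_shapes(shapes):
    """Return (n_cells, n_genes) or raise ValueError on inconsistent shapes."""
    missing = [key for key in REQUIRED_KEYS if key not in shapes]
    if missing:
        raise ValueError(f"missing input keys: {missing}")

    counts_shape = tuple(shapes["counts"])
    if len(counts_shape) != 2:
        raise ValueError(f"counts must be 2-D, got shape {counts_shape}")
    n_cells, n_genes = counts_shape

    # ambient_profile has one entry per gene, batch_labels one per cell.
    if tuple(shapes["ambient_profile"]) != (n_genes,):
        raise ValueError("ambient_profile must have shape (n_genes,)")
    if tuple(shapes["batch_labels"]) != (n_cells,):
        raise ValueError("batch_labels must have shape (n_cells,)")
    return n_cells, n_genes
```

With real arrays, the call would look like `check_input_shapes({k: v.shape for k, v in data.items()})`.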
SingleCellPipelineConfig
diffbio.pipelines.single_cell.SingleCellPipelineConfig
dataclass
SingleCellPipelineConfig(
    n_genes: int = 2000,
    n_clusters: int = 10,
    latent_dim: int = 64,
    hidden_dims: tuple[int, ...] = (128, 64),
    umap_n_components: int = 2,
    batch_correction_clusters: int = 100,
    batch_correction_iterations: int = 10,
    clustering_temperature: float = 1.0,
    enable_ambient_removal: bool = True,
    enable_batch_correction: bool = True,
    enable_dim_reduction: bool = True,
    enable_clustering: bool = True,
)
Bases: OperatorConfig
Configuration for the single-cell analysis pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| `n_genes` | `int` | Number of genes in the expression matrix. |
| `n_clusters` | `int` | Number of clusters for soft k-means. |
| `latent_dim` | `int` | Dimension of the VAE latent space. |
| `hidden_dims` | `tuple[int, ...]` | Hidden layer dimensions for the VAE. |
| `umap_n_components` | `int` | Number of UMAP output dimensions. |
| `batch_correction_clusters` | `int` | Number of clusters for Harmony. |
| `batch_correction_iterations` | `int` | Number of Harmony iterations. |
| `clustering_temperature` | `float` | Temperature for soft clustering. |
| `enable_ambient_removal` | `bool` | Whether to enable ambient RNA removal. |
| `enable_batch_correction` | `bool` | Whether to enable batch correction. |
| `enable_dim_reduction` | `bool` | Whether to enable UMAP dimensionality reduction. |
| `enable_clustering` | `bool` | Whether to enable soft clustering. |
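To make the role of `clustering_temperature` concrete, here is a minimal pure-Python sketch of temperature-scaled soft assignment in a common soft k-means form: a softmax over negative squared distances divided by the temperature. This page does not state the exact formula diffbio uses, so `soft_assign` and its choice of distance are assumptions for illustration only.

```python
import math


def soft_assign(point, centroids, temperature=1.0):
    """Softmax over negative squared distances, scaled by temperature.

    Illustrative only: diffbio's actual clustering formula may differ.
    """
    # Squared Euclidean distance from the point to each centroid.
    sq_dists = [
        sum((p - c) ** 2 for p, c in zip(point, centroid))
        for centroid in centroids
    ]
    logits = [-d / temperature for d in sq_dists]
    # Subtract the max logit for numerical stability.
    max_logit = max(logits)
    weights = [math.exp(logit - max_logit) for logit in logits]
    total = sum(weights)
    return [w / total for w in weights]


centroids = [[0.0, 0.0], [2.0, 0.0]]
point = [0.5, 0.0]  # closer to the first centroid
sharp = soft_assign(point, centroids, temperature=0.5)
smooth = soft_assign(point, centroids, temperature=5.0)
```

In this formulation, lower temperatures push the assignment toward one-hot, while higher temperatures spread probability more evenly across clusters.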
Factory Function
create_single_cell_pipeline
diffbio.pipelines.single_cell.create_single_cell_pipeline
create_single_cell_pipeline(
    n_genes: int = 2000,
    n_clusters: int = 10,
    latent_dim: int = 64,
    umap_n_components: int = 2,
    enable_ambient_removal: bool = True,
    enable_batch_correction: bool = True,
    enable_dim_reduction: bool = True,
    enable_clustering: bool = True,
    seed: int = 42,
) -> SingleCellPipeline
Factory function to create a single-cell analysis pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `n_genes` | `int` | Number of genes in the expression matrix. | `2000` |
| `n_clusters` | `int` | Number of clusters for soft k-means. | `10` |
| `latent_dim` | `int` | Dimension of the VAE latent space. | `64` |
| `umap_n_components` | `int` | Number of UMAP output dimensions. | `2` |
| `enable_ambient_removal` | `bool` | Whether to enable ambient RNA removal. | `True` |
| `enable_batch_correction` | `bool` | Whether to enable batch correction. | `True` |
| `enable_dim_reduction` | `bool` | Whether to enable UMAP. | `True` |
| `enable_clustering` | `bool` | Whether to enable soft clustering. | `True` |
| `seed` | `int` | Random seed. | `42` |
Returns:
| Type | Description |
|---|---|
| `SingleCellPipeline` | Configured SingleCellPipeline instance. |
Usage Examples
Quick Start
from diffbio.pipelines import create_single_cell_pipeline
import jax
import jax.numpy as jnp

# Create pipeline
pipeline = create_single_cell_pipeline(
    n_genes=2000,
    n_clusters=10,
    latent_dim=64,
)

# Prepare data
n_cells = 100
n_genes = 2000
data = {
    "counts": jax.random.poisson(
        jax.random.PRNGKey(0), lam=5.0, shape=(n_cells, n_genes)
    ).astype(jnp.float32),
    "ambient_profile": jnp.ones(n_genes) / n_genes,
    "batch_labels": jax.random.randint(jax.random.PRNGKey(1), (n_cells,), 0, 3),
}

# Run pipeline
result, _, _ = pipeline.apply(data, {}, None)
clusters = jnp.argmax(result["cluster_assignments"], axis=-1)
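Taking the `argmax` discards how confident each assignment is. As a rough sketch in plain Python (no diffbio or JAX required), the hypothetical helper `summarize_assignments` below takes rows of soft cluster probabilities, such as `result["cluster_assignments"].tolist()`, and returns each hard label together with its probability so that ambiguous cells can be flagged. The 0.6 threshold is an arbitrary value chosen for illustration, not a library default.

```python
def summarize_assignments(probs, min_confidence=0.6):
    """Return (label, confidence, is_ambiguous) for each row of probabilities."""
    summary = []
    for row in probs:
        # Hard label: index of the largest probability (argmax).
        label = max(range(len(row)), key=row.__getitem__)
        confidence = row[label]
        summary.append((label, confidence, confidence < min_confidence))
    return summary


# Toy soft assignments for three cells and three clusters.
toy = [
    [0.90, 0.05, 0.05],
    [0.40, 0.35, 0.25],
    [0.10, 0.20, 0.70],
]
report = summarize_assignments(toy)
```

Cells flagged as ambiguous are candidates for closer inspection, for example cells near the boundary between two clusters.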
Full Configuration
from diffbio.pipelines import SingleCellPipeline, SingleCellPipelineConfig
from flax import nnx

config = SingleCellPipelineConfig(
    n_genes=5000,
    n_clusters=15,
    latent_dim=128,
    hidden_dims=(256, 128),
    umap_n_components=2,
    batch_correction_clusters=50,
    batch_correction_iterations=20,
    clustering_temperature=0.5,
    enable_ambient_removal=True,
    enable_batch_correction=True,
    enable_dim_reduction=True,
    enable_clustering=True,
)
pipeline = SingleCellPipeline(config, rngs=nnx.Rngs(42))
# Note: this pipeline has no training-mode toggle; submodules manage their
# own dropout/training state when applicable.
Training Mode
# SingleCellPipeline does not expose train_mode/eval_mode toggles.
# Submodules that use dropout manage their own state during apply().
# `dataloader` and `train_step` are user-defined placeholders, not diffbio APIs.
for batch in dataloader:
    loss = train_step(pipeline, batch)
Access Components
# Ambient removal (if enabled)
if pipeline.ambient_removal is not None:
    pipeline.ambient_removal

# VAE normalizer (always available)
pipeline.vae_normalizer

# Batch correction (if enabled)
if pipeline.batch_correction is not None:
    pipeline.batch_correction

# Dimensionality reduction (if enabled)
if pipeline.dim_reduction is not None:
    pipeline.dim_reduction

# Clustering (if enabled)
if pipeline.clustering is not None:
    pipeline.clustering
    pipeline.clustering.centroids  # Cluster centers
Minimal Pipeline
# Create pipeline with only VAE normalization
minimal_pipeline = create_single_cell_pipeline(
    n_genes=2000,
    enable_ambient_removal=False,
    enable_batch_correction=False,
    enable_dim_reduction=False,
    enable_clustering=False,
)
Input Specifications
| Key | Shape | Description |
|---|---|---|
| `counts` | `(n_cells, n_genes)` | Raw count matrix |
| `ambient_profile` | `(n_genes,)` | Ambient expression profile (normalized) |
| `batch_labels` | `(n_cells,)` | Integer batch assignments |
Output Specifications
| Key | Shape | Description |
|---|---|---|
| `counts` | `(n_cells, n_genes)` | Original count matrix |
| `ambient_profile` | `(n_genes,)` | Original ambient profile |
| `batch_labels` | `(n_cells,)` | Original batch labels |
| `decontaminated_counts` | `(n_cells, n_genes)` | Ambient-removed counts* |
| `normalized` | `(n_cells, n_genes)` | VAE-normalized expression |
| `latent` | `(n_cells, latent_dim)` | Latent space representation |
| `corrected_embeddings` | `(n_cells, latent_dim)` | Batch-corrected embeddings* |
| `embeddings_2d` | `(n_cells, umap_n_components)` | UMAP embeddings (2D with the default `umap_n_components=2`)* |
| `cluster_assignments` | `(n_cells, n_clusters)` | Soft cluster probabilities |

*Only present when the corresponding stage is enabled: enable_ambient_removal=True, enable_batch_correction=True, and enable_dim_reduction=True, respectively.
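The shape table can be turned into a quick lookup for downstream code. The sketch below is a hypothetical helper, not a diffbio function. It encodes the rows above and treats `decontaminated_counts`, `corrected_embeddings`, and `embeddings_2d` as optional, following the "(if enabled)" notes in the class description. This page does not say whether `cluster_assignments` is omitted when `enable_clustering=False`, so gating it on that flag is an assumption.

```python
def expected_output_shapes(
    n_cells,
    n_genes=2000,
    latent_dim=64,
    umap_n_components=2,
    n_clusters=10,
    enable_ambient_removal=True,
    enable_batch_correction=True,
    enable_dim_reduction=True,
    enable_clustering=True,
):
    """Map each documented output key to its expected shape."""
    shapes = {
        # Input keys are passed through unchanged.
        "counts": (n_cells, n_genes),
        "ambient_profile": (n_genes,),
        "batch_labels": (n_cells,),
        # Listed without an "(if enabled)" qualifier in the documentation.
        "normalized": (n_cells, n_genes),
        "latent": (n_cells, latent_dim),
    }
    if enable_ambient_removal:
        shapes["decontaminated_counts"] = (n_cells, n_genes)
    if enable_batch_correction:
        shapes["corrected_embeddings"] = (n_cells, latent_dim)
    if enable_dim_reduction:
        shapes["embeddings_2d"] = (n_cells, umap_n_components)
    if enable_clustering:  # assumption: gated like the other optional stages
        shapes["cluster_assignments"] = (n_cells, n_clusters)
    return shapes


full = expected_output_shapes(n_cells=100)
minimal = expected_output_shapes(
    n_cells=100,
    enable_ambient_removal=False,
    enable_batch_correction=False,
    enable_dim_reduction=False,
    enable_clustering=False,
)
```

A result from `apply` could then be compared against this mapping with `{k: v.shape for k, v in result.items()}`.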