
Single-Cell Operators API

Differentiable operators for single-cell analysis including clustering, batch correction, and RNA velocity.

SoftKMeansClustering

diffbio.operators.singlecell.soft_clustering.SoftKMeansClustering

SoftKMeansClustering(
    config: SoftClusteringConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable soft k-means clustering.

This operator implements soft k-means with learnable cluster centroids. Instead of hard cluster assignments, cells are softly assigned to clusters using softmax over negative squared distances.

Algorithm:
  1. Compute the squared distance from each cell to each centroid.
  2. Apply a softmax to obtain soft assignments: P(k|x) = softmax_k(-||x - c_k||² / T), where T is the temperature.
  3. Optionally update the centroids as weighted means of the cells.
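The three steps can be sketched in plain JAX. This is a minimal illustration, not the operator's implementation: the helper names are hypothetical, and the centroid update assumes the standard soft k-means rule (each centroid becomes the assignment-weighted mean of the cells). A small loss on top of the soft assignments shows that gradients reach the centroids, which is what makes them learnable.

```python
import jax
import jax.numpy as jnp


def squared_distances(x, centroids):
    """||x_i - c_k||^2 for every cell i and centroid k -> (n_cells, n_clusters)."""
    return jnp.sum((x[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)


def soft_kmeans_step(x, centroids, temperature=1.0):
    """Hypothetical helper: soft assignments plus assignment-weighted centroids."""
    # Step 2: P(k|x) = softmax(-||x - c_k||^2 / T) over the cluster axis
    assignments = jax.nn.softmax(-squared_distances(x, centroids) / temperature, axis=-1)
    # Step 3 (assumed rule): each centroid is the assignment-weighted mean of the cells
    weights = assignments / (assignments.sum(axis=0, keepdims=True) + 1e-8)
    new_centroids = weights.T @ x
    return assignments, new_centroids


def soft_kmeans_loss(centroids, x, temperature=1.0):
    """Expected squared distance under the soft assignments (differentiable)."""
    sq_dist = squared_distances(x, centroids)
    assignments = jax.nn.softmax(-sq_dist / temperature, axis=-1)
    return jnp.mean(jnp.sum(assignments * sq_dist, axis=-1))


x = jnp.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
centroids = jnp.array([[0.0, 0.0], [5.0, 5.0]])
assignments, new_centroids = soft_kmeans_step(x, centroids, temperature=0.5)
hard_labels = jnp.argmax(assignments, axis=-1)  # analogous to "cluster_labels"
centroid_grads = jax.grad(soft_kmeans_loss)(centroids, x)  # gradients w.r.t. the centroids
```

Lowering the temperature pushes the soft assignments toward one-hot vectors, recovering ordinary k-means behaviour in the limit.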

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection

Parameters:

  • config (SoftClusteringConfig, required): SoftClusteringConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

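Example
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.singlecell.soft_clustering import SoftClusteringConfig, SoftKMeansClustering
config = SoftClusteringConfig(n_clusters=10, n_features=50)
clusterer = SoftKMeansClustering(config, rngs=nnx.Rngs(42))
cell_embeddings = jnp.ones((500, 50))  # placeholder embeddings (n_cells, n_features)
data = {"embeddings": cell_embeddings}
result, state, meta = clusterer.apply(data, {}, None)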

Parameters:

  • config (SoftClusteringConfig, required): Clustering configuration.
  • rngs (Rngs | None, default None): Random number generators for initialization.
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply soft k-means clustering to cell embeddings.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "embeddings": Cell embeddings (n_cells, n_features)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used.
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where

  • transformed_data contains:
      - "embeddings": Original embeddings
      - "cluster_assignments": Soft assignment probabilities
      - "cluster_labels": Hard cluster labels
      - "centroids": Cluster centroid positions
  • state is passed through unchanged
  • metadata is passed through unchanged

SoftClusteringConfig

diffbio.operators.singlecell.soft_clustering.SoftClusteringConfig (dataclass)

SoftClusteringConfig(
    n_clusters: int = 10,
    n_features: int = 50,
    temperature: float = 1.0,
    learnable_centroids: bool = True,
)

Bases: OperatorConfig

Configuration for SoftKMeansClustering.

Attributes:

  • n_clusters (int): Number of clusters.
  • n_features (int): Dimensionality of input embeddings.
  • temperature (float): Temperature for the softmax (lower values give sharper assignments).
  • learnable_centroids (bool): Whether the centroids are learnable parameters.

DifferentiableHarmony

diffbio.operators.singlecell.batch_correction.DifferentiableHarmony

DifferentiableHarmony(
    config: BatchCorrectionConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable Harmony-style batch correction.

This operator implements iterative batch correction using soft clustering with batch-aware updates. Because the number of iterations is fixed (n_iterations), the correction loop has a static length, so gradients can flow through the entire correction process.

Algorithm:
  1. Initialize cluster centroids from the data.
  2. Softly assign cells to clusters.
  3. Compute batch-aware centroid corrections.
  4. Move cell embeddings toward the corrected centroids.
  5. Repeat for n_iterations iterations.
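A heavily simplified, Harmony-style loop illustrates why the fixed iteration count matters: when `jax.lax.fori_loop` is given Python-int bounds it lowers to a scan and stays reverse-mode differentiable. The correction rule below (shift each cell by the assignment-weighted offset between its batch's cluster mean and the overall cluster mean), the function name, and the single-cluster toy data are assumptions for illustration only; the sketch omits the theta diversity penalty and is not the operator's actual update.

```python
import jax
import jax.numpy as jnp


def harmony_style_correct(x, batch, n_batches, centroids, n_iterations=10, temperature=1.0):
    """Hypothetical simplified batch correction with a fixed number of iterations."""
    batch_onehot = jax.nn.one_hot(batch, n_batches)  # (n_cells, n_batches)

    def one_iteration(_, z):
        # Soft assignment of cells to clusters: (n_cells, n_clusters)
        sq_dist = jnp.sum((z[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
        r = jax.nn.softmax(-sq_dist / temperature, axis=-1)
        # Cluster means and per-(cluster, batch) means of the current embeddings
        cluster_means = (r.T @ z) / (r.sum(axis=0)[:, None] + 1e-8)  # (K, d)
        w = r[:, :, None] * batch_onehot[:, None, :]  # (n_cells, K, n_batches)
        batch_means = jnp.einsum("nkb,nd->kbd", w, z) / (w.sum(axis=0)[..., None] + 1e-8)
        offsets = batch_means - cluster_means[:, None, :]  # (K, n_batches, d)
        # Each cell moves by the assignment-weighted offset of its own batch
        cell_offsets = jnp.einsum("nk,nb,kbd->nd", r, batch_onehot, offsets)
        return z - cell_offsets

    # Python-int bounds: fori_loop lowers to a scan, so gradients flow through every iteration
    return jax.lax.fori_loop(0, n_iterations, one_iteration, x)


x = jnp.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [4.0, 0.0]])
batch = jnp.array([0, 0, 1, 1])
centroids = jnp.array([[2.0, 0.0]])  # a single cluster keeps the arithmetic easy to follow
corrected = harmony_style_correct(x, batch, n_batches=2, centroids=centroids)
grad_x = jax.grad(lambda x: jnp.sum(harmony_style_correct(x, batch, 2, centroids) ** 2))(x)
```

With one cluster, the first iteration moves both batch means onto the overall mean (2.0 along the first axis), and later iterations leave the embeddings unchanged.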

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection

Parameters:

  • config (BatchCorrectionConfig, required): BatchCorrectionConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

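Example
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.singlecell.batch_correction import BatchCorrectionConfig, DifferentiableHarmony
config = BatchCorrectionConfig(n_clusters=100, n_batches=3)
harmony = DifferentiableHarmony(config, rngs=nnx.Rngs(42))
X, batch = jnp.ones((500, 50)), jnp.arange(500) % 3  # placeholder embeddings and batch labels
data = {"embeddings": X, "batch_labels": batch}
result, state, meta = harmony.apply(data, {}, None)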

Parameters:

  • config (BatchCorrectionConfig, required): Batch correction configuration.
  • rngs (Rngs | None, default None): Random number generators for initialization.
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply batch correction to cell embeddings.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "embeddings": Cell embeddings (n_cells, n_features)
      - "batch_labels": Batch assignments (n_cells,)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used.
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where

  • transformed_data contains:
      - "embeddings": Original embeddings
      - "batch_labels": Original batch labels
      - "corrected_embeddings": Batch-corrected embeddings
      - "cluster_assignments": Final soft cluster assignments
  • state is passed through unchanged
  • metadata is passed through unchanged

BatchCorrectionConfig

diffbio.operators.singlecell.batch_correction.BatchCorrectionConfig (dataclass)

BatchCorrectionConfig(
    n_clusters: int = 100,
    n_features: int = 50,
    n_batches: int = 2,
    n_iterations: int = 10,
    theta: float = 2.0,
    sigma: float = 0.1,
    temperature: float = 1.0,
)

Bases: OperatorConfig

Configuration for DifferentiableHarmony.

Attributes:

  • n_clusters (int): Number of clusters for soft assignment.
  • n_features (int): Dimensionality of input embeddings.
  • n_batches (int): Number of distinct batches.
  • n_iterations (int): Number of correction iterations.
  • theta (float): Diversity penalty parameter.
  • sigma (float): Soft assignment bandwidth.
  • temperature (float): Temperature for softmax operations.

DifferentiableVelocity

diffbio.operators.singlecell.velocity.DifferentiableVelocity

DifferentiableVelocity(
    config: VelocityConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable RNA velocity estimation via Neural ODEs.

This operator estimates RNA velocity from spliced and unspliced counts using learned kinetics parameters and differentiable ODE integration.

Algorithm:
  1. Encode expression to a latent time per cell.
  2. Learn per-gene kinetics (alpha, beta, gamma).
  3. Compute velocity from the splicing ODE, where the RNA velocity is ds/dt:
       du/dt = alpha - beta * u
       ds/dt = beta * u - gamma * s
  4. Integrate the ODE with the Euler method.
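The splicing ODE and a fixed-step Euler integrator fit in a few lines of JAX. This sketch uses hypothetical helper names and made-up kinetic rates for two genes; the operator itself learns alpha, beta and gamma and encodes a latent time per cell, neither of which is shown here.

```python
import jax.numpy as jnp


def splicing_rhs(u, s, alpha, beta, gamma):
    """Right-hand side of the splicing ODE; ds/dt is the RNA velocity."""
    du_dt = alpha - beta * u
    ds_dt = beta * u - gamma * s
    return du_dt, ds_dt


def euler_integrate(u, s, alpha, beta, gamma, dt=0.1, n_steps=10):
    """Explicit Euler with a fixed step count (arguments named after VelocityConfig)."""
    for _ in range(n_steps):  # static Python loop: unrolled under jit, differentiable
        du_dt, ds_dt = splicing_rhs(u, s, alpha, beta, gamma)
        u, s = u + dt * du_dt, s + dt * ds_dt
    return u, s


# Hypothetical per-gene kinetics for two genes
alpha = jnp.array([2.0, 1.0])
beta = jnp.array([1.0, 0.5])
gamma = jnp.array([0.5, 0.25])
u0 = jnp.array([1.0, 1.0])
s0 = jnp.array([0.5, 0.5])

_, velocity = splicing_rhs(u0, s0, alpha, beta, gamma)  # [0.75, 0.375]
u_T, s_T = euler_integrate(u0, s0, alpha, beta, gamma, dt=0.1, n_steps=10)
```

At u = alpha / beta and s = alpha / gamma both derivatives vanish, so a gene at steady state has zero velocity.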

Parameters:

  • config (VelocityConfig, required): VelocityConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

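Example
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.singlecell.velocity import DifferentiableVelocity, VelocityConfig
config = VelocityConfig(n_genes=2000)
velocity = DifferentiableVelocity(config, rngs=nnx.Rngs(42))
spliced = unspliced = jnp.ones((100, 2000))  # placeholder counts (n_cells, n_genes)
data = {"spliced": spliced, "unspliced": unspliced}
result, state, meta = velocity.apply(data, {}, None)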

Parameters:

  • config (VelocityConfig, required): Velocity configuration.
  • rngs (Rngs | None, default None): Random number generators for initialization.
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply RNA velocity estimation.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "spliced": Spliced mRNA counts (n_cells, n_genes)
      - "unspliced": Unspliced mRNA counts (n_cells, n_genes)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used.
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where

  • transformed_data contains:
      - "spliced": Original spliced counts
      - "unspliced": Original unspliced counts
      - "velocity": RNA velocity estimates
      - "latent_time": Estimated latent time per cell
      - "alpha": Transcription rate per gene
      - "beta": Splicing rate per gene
      - "gamma": Degradation rate per gene
      - "projected_spliced": Projected future spliced counts
  • state is passed through unchanged
  • metadata is passed through unchanged

VelocityConfig

diffbio.operators.singlecell.velocity.VelocityConfig (dataclass)

VelocityConfig(
    n_genes: int = 2000,
    hidden_dim: int = 64,
    dt: float = 0.1,
    n_steps: int = 10,
    kinetics_model: str = "standard",
)

Bases: OperatorConfig

Configuration for DifferentiableVelocity.

Attributes:

  • n_genes (int): Number of genes.
  • hidden_dim (int): Hidden dimension for neural networks.
  • dt (float): Time step for ODE integration.
  • n_steps (int): Number of integration steps.
  • kinetics_model (str): Type of kinetics model ("standard" or "dynamical").

DifferentiableAmbientRemoval

diffbio.operators.singlecell.ambient_removal.DifferentiableAmbientRemoval

DifferentiableAmbientRemoval(
    config: AmbientRemovalConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: EncoderDecoderOperator

Differentiable ambient RNA removal using VAE.

This operator removes ambient RNA contamination from single-cell count data using a variational autoencoder that models both cell-intrinsic expression and ambient contamination.

Algorithm:
  1. Encode counts to the latent space plus a per-cell contamination fraction.
  2. Sample the latent with the reparameterization trick.
  3. Decode to a cell-intrinsic expression rate.
  4. Compute decontaminated counts by subtracting the ambient contribution.
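Steps 2 and 4 can be illustrated in isolation. The reparameterization below is the standard VAE trick; the subtraction rule (remove contamination_fraction × library size × ambient_profile from each cell, clipped at zero) is an assumed model of "the ambient contribution", not a description of the operator's exact formula. Function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, logvar):
    """z = mean + exp(0.5 * logvar) * eps, eps ~ N(0, I); differentiable in mean and logvar."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps


def subtract_ambient(counts, contamination_fraction, ambient_profile):
    """Hypothetical decontamination rule: remove the expected ambient counts per cell.

    counts: (n_cells, n_genes); contamination_fraction: (n_cells,);
    ambient_profile: (n_genes,), assumed to sum to 1.
    """
    library_size = counts.sum(axis=1, keepdims=True)  # (n_cells, 1)
    expected_ambient = contamination_fraction[:, None] * library_size * ambient_profile[None, :]
    return jnp.maximum(counts - expected_ambient, 0.0)  # counts cannot go negative


counts = jnp.array([[10.0, 0.0], [4.0, 4.0]])
contamination_fraction = jnp.array([0.1, 0.5])
ambient_profile = jnp.array([0.5, 0.5])
decontaminated = subtract_ambient(counts, contamination_fraction, ambient_profile)
# -> [[9.5, 0.0], [2.0, 2.0]]

z = reparameterize(jax.random.key(0), jnp.zeros((2, 4)), jnp.zeros((2, 4)))
```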

Inherits from EncoderDecoderOperator to get:

  • reparameterize() for sampling with reparameterization trick
  • kl_divergence() for KL from standard normal
  • elbo_loss() for combining reconstruction and KL losses

Parameters:

  • config (AmbientRemovalConfig, required): AmbientRemovalConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

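Example
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.singlecell.ambient_removal import AmbientRemovalConfig, DifferentiableAmbientRemoval
config = AmbientRemovalConfig(n_genes=2000)
remover = DifferentiableAmbientRemoval(config, rngs=nnx.Rngs(42))
counts, ambient = jnp.ones((100, 2000)), jnp.full((2000,), 1 / 2000)  # placeholder inputs
data = {"counts": counts, "ambient_profile": ambient}
result, state, meta = remover.apply(data, {}, None)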

Parameters:

  • config (AmbientRemovalConfig, required): Ambient removal configuration.
  • rngs (Rngs | None, default None): Random number generators for initialization.
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply ambient RNA removal.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "counts": Raw count matrix (n_cells, n_genes)
      - "ambient_profile": Ambient expression profile (n_genes,)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Random key for stochastic sampling.
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where

  • transformed_data contains:
      - "counts": Original counts
      - "ambient_profile": Original ambient profile
      - "decontaminated_counts": Decontaminated counts
      - "contamination_fraction": Estimated contamination per cell
      - "latent": Latent representation
      - "latent_mean": Mean of latent distribution
      - "latent_logvar": Log variance of latent distribution
      - "reconstructed": Reconstructed expression
  • state is passed through unchanged
  • metadata is passed through unchanged

AmbientRemovalConfig

diffbio.operators.singlecell.ambient_removal.AmbientRemovalConfig (dataclass)

AmbientRemovalConfig(
    n_genes: int = 2000,
    latent_dim: int = 64,
    hidden_dims: list[int] = field(default_factory=lambda: [256, 128]),
    ambient_prior: float = 0.01,
    temperature: float = 1.0,
)

Bases: OperatorConfig

Configuration for DifferentiableAmbientRemoval.

Attributes:

  • n_genes (int): Number of genes in expression profiles.
  • latent_dim (int): Dimension of the latent space.
  • hidden_dims (list[int]): Hidden layer dimensions for encoder/decoder.
  • ambient_prior (float): Prior probability of ambient contamination.
  • temperature (float): Temperature for softmax operations.

DifferentiableCellAnnotator

diffbio.operators.singlecell.cell_annotation.DifferentiableCellAnnotator

DifferentiableCellAnnotator(
    config: CellAnnotatorConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: CountReconstructionMixin, CountVAEBackboneMixin, EncoderDecoderOperator

Differentiable cell type annotator with three annotation modes.

Modes

celltypist (logistic regression on latent): Encode counts to a VAE latent, apply a linear classifier head, softmax.

cellassign (marker-gene likelihood): Given a binary marker matrix M, compute per-type Poisson log-likelihoods with learnable rate parameters, then softmax.

scanvi (semi-supervised VAE with type-conditioned prior): VAE encoder + classifier head with learnable per-type Gaussian priors in latent space. The KL divergence uses p(z|y) = N(mu_y, sigma_y) instead of the standard N(0, I), and is marginalised over predicted type probabilities for unlabelled cells.

All modes additionally produce a latent representation via a shared VAE encoder.
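For the scanvi mode, the KL term against a type-conditioned Gaussian prior, averaged over predicted type probabilities, has a closed form for diagonal Gaussians. The sketch below assumes each per-type prior is parameterized by a mean and a log-variance per latent dimension; the function names are hypothetical and this is not the operator's source.

```python
import jax.numpy as jnp


def gaussian_kl(mean, logvar, prior_mean, prior_logvar):
    """KL(N(mean, e^logvar) || N(prior_mean, e^prior_logvar)), summed over latent dims."""
    return 0.5 * jnp.sum(
        prior_logvar
        - logvar
        + (jnp.exp(logvar) + (mean - prior_mean) ** 2) / jnp.exp(prior_logvar)
        - 1.0,
        axis=-1,
    )


def marginal_type_kl(mean, logvar, type_probs, type_means, type_logvars):
    """Hypothetical helper: sum_y q(y|x) * KL(q(z|x) || p(z|y)) for each cell.

    mean, logvar: (n_cells, latent_dim); type_probs: (n_cells, n_types);
    type_means, type_logvars: (n_types, latent_dim) (assumed log-variance parameterization).
    """
    kl_per_type = gaussian_kl(
        mean[:, None, :], logvar[:, None, :], type_means[None, :, :], type_logvars[None, :, :]
    )  # (n_cells, n_types)
    return jnp.sum(type_probs * kl_per_type, axis=-1)


mean = jnp.array([[1.0, 0.0]])
logvar = jnp.zeros((1, 2))
type_means = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # type 1 prior is N(0, I)
type_logvars = jnp.zeros((2, 2))

kl_type0 = marginal_type_kl(mean, logvar, jnp.array([[1.0, 0.0]]), type_means, type_logvars)  # [0.0]
kl_type1 = marginal_type_kl(mean, logvar, jnp.array([[0.0, 1.0]]), type_means, type_logvars)  # [0.5]
kl_mixed = marginal_type_kl(mean, logvar, jnp.array([[0.5, 0.5]]), type_means, type_logvars)  # [0.25]
```

With one-hot type probabilities and an N(0, I) prior the expression reduces to the standard VAE KL term, which is what kl_type1 computes.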

Inherits from EncoderDecoderOperator to get:

  • reparameterize() for the VAE sampling step
  • kl_divergence() for KL from standard normal
  • elbo_loss() for combining reconstruction and KL losses

Parameters:

  • config (CellAnnotatorConfig, required): CellAnnotatorConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

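Example
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.singlecell.cell_annotation import (
    CellAnnotatorConfig,
    DifferentiableCellAnnotator,
)
config = CellAnnotatorConfig(
    annotation_mode="celltypist",
    n_cell_types=10,
    n_genes=2000,
    stochastic=True,
    stream_name="sample",
)
annotator = DifferentiableCellAnnotator(config, rngs=nnx.Rngs(42))
counts = jnp.ones((100, 2000))  # placeholder counts (n_cells, n_genes)
data = {"counts": counts}
result, state, meta = annotator.apply(data, {}, None)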

Parameters:

  • config (CellAnnotatorConfig, required): Annotator configuration.
  • rngs (Rngs | None, default None): Random number generators for initialisation and sampling.
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Annotate cells with type probabilities.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "counts": Gene expression counts (n, n_genes)
      - "marker_matrix" (cellassign mode): Binary matrix (n_types, n_genes)
      - "known_labels" (scanvi mode): Integer labels (n_labeled,)
      - "label_indices" (scanvi mode): Indices of the labelled cells within the batch (n_labeled,)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used.
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where transformed_data adds:

  • "cell_type_probabilities": (n, n_cell_types)
  • "cell_type_labels": (n,) argmax labels
  • "latent": (n, latent_dim)

CellAnnotatorConfig

diffbio.operators.singlecell.cell_annotation.CellAnnotatorConfig (dataclass)

CellAnnotatorConfig(
    annotation_mode: Literal[
        "scanvi", "cellassign", "celltypist"
    ] = "celltypist",
    n_cell_types: int = 10,
    n_genes: int = 2000,
    latent_dim: int = 10,
    hidden_dims: list[int] = field(default_factory=lambda: [128, 64]),
    marker_matrix_shape: tuple[int, int] | None = None,
    gene_likelihood: Literal["poisson", "zinb"] = "poisson",
)

Bases: OperatorConfig

Configuration for cell type annotation.

Attributes:

  • annotation_mode (Literal['scanvi', 'cellassign', 'celltypist']): Annotation strategy to use.
  • n_cell_types (int): Number of cell types to classify.
  • n_genes (int): Number of input genes.
  • latent_dim (int): Latent-space dimensionality for the VAE encoder.
  • hidden_dims (list[int]): Hidden layer sizes for encoder and decoder.
  • marker_matrix_shape (tuple[int, int] | None): Shape (n_types, n_genes) for cellassign mode.
  • gene_likelihood (Literal['poisson', 'zinb']): Reconstruction likelihood for scanvi mode: "poisson" for the standard Poisson NLL (default), "zinb" for the zero-inflated negative binomial.

DifferentiableDiffusionImputer

diffbio.operators.singlecell.imputation.DifferentiableDiffusionImputer

DifferentiableDiffusionImputer(
    config: DiffusionImputerConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable MAGIC-style diffusion imputation.

Constructs a cell-cell affinity graph using an alpha-decaying kernel, symmetrizes it, builds a row-stochastic Markov matrix M = D^{-1} A, and computes M^t via repeated matrix multiplication for imputation. This avoids eigendecomposition (whose backward pass produces NaN when eigenvalues are near-degenerate) while remaining fully differentiable.

Algorithm
  1. Compute pairwise distances between cells
  2. Build alpha-decay affinity: K(i,j) = exp(-(d(i,j)/sigma_i)^decay), where sigma_i is the distance from cell i to its n_neighbors-th nearest neighbor
  3. Symmetrize the affinity via fuzzy set union
  4. Row-normalize to Markov matrix M = D^{-1} A
  5. Compute M^t via repeated matrix multiplication (t iterations)
  6. Impute: imputed = M^t @ counts
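The six steps map almost line for line onto JAX. The sketch below is an illustration under stated assumptions: Euclidean distances on the raw counts (no PCA step), the n_neighbors-th neighbor distance as sigma_i, and the fuzzy set union written as A = K + K^T - K * K^T. The function name is hypothetical and this is not the operator's source.

```python
import jax
import jax.numpy as jnp


def diffusion_impute(counts, n_neighbors=5, decay=1.0, t=3):
    """Hypothetical MAGIC-style imputation following steps 1-6 (Euclidean, no PCA)."""
    # 1. Pairwise Euclidean distances; the floor avoids sqrt's infinite gradient at zero
    sq = jnp.sum((counts[:, None, :] - counts[None, :, :]) ** 2, axis=-1)
    dist = jnp.sqrt(jnp.maximum(sq, 1e-12))
    # Local bandwidth: distance to the n_neighbors-th neighbor (column 0 is the cell itself)
    sigma = jnp.sort(dist, axis=1)[:, n_neighbors]
    # 2. Alpha-decay affinity K(i, j) = exp(-(d(i, j) / sigma_i) ** decay)
    K = jnp.exp(-((dist / sigma[:, None]) ** decay))
    # 3. Fuzzy set union: A = K + K^T - K * K^T (symmetric, entries in [0, 1])
    A = K + K.T - K * K.T
    # 4. Row-normalize to a Markov matrix M = D^{-1} A
    M = A / A.sum(axis=1, keepdims=True)
    # 5. M^t by repeated multiplication (no eigendecomposition)
    M_t = jnp.eye(M.shape[0], dtype=M.dtype)
    for _ in range(t):
        M_t = M_t @ M
    # 6. Impute
    return M_t @ counts, M_t


counts = jnp.abs(jax.random.normal(jax.random.key(0), (20, 5)))
imputed, diffusion_operator = diffusion_impute(counts, n_neighbors=5, decay=1.0, t=3)
```

Because a product of row-stochastic matrices is row-stochastic, every row of M^t sums to 1, so each imputed cell is a weighted average of observed cells.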

Parameters:

  • config (DiffusionImputerConfig, required): DiffusionImputerConfig with operator parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators (not used; kept for API consistency).
  • name (str | None, default None): Optional operator name.

Example

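>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from diffbio.operators.singlecell.imputation import (
...     DiffusionImputerConfig, DifferentiableDiffusionImputer)
>>> config = DiffusionImputerConfig(n_neighbors=5, diffusion_t=3)
>>> imputer = DifferentiableDiffusionImputer(config, rngs=nnx.Rngs(0))
>>> data = {"counts": jnp.ones((100, 2000))}
>>> result, state, meta = imputer.apply(data, {}, None)
>>> result["imputed_counts"].shape
(100, 2000)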

Parameters:

  • config (DiffusionImputerConfig, required): Imputation configuration.
  • rngs (Rngs | None, default None): Random number generators (unused, present for API consistency).
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply diffusion imputation to single-cell count data.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "counts": Gene expression matrix (n_cells, n_genes)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used (deterministic operator).
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where

  • transformed_data contains:
      - "counts": Original counts
      - "imputed_counts": Diffusion-imputed counts
      - "diffusion_operator": The M^t matrix
  • state is passed through unchanged
  • metadata is passed through unchanged

DiffusionImputerConfig

diffbio.operators.singlecell.imputation.DiffusionImputerConfig (dataclass)

DiffusionImputerConfig(
    n_neighbors: int = 5,
    diffusion_t: int = 3,
    n_pca_components: int = 100,
    decay: float = 1.0,
    metric: str = "euclidean",
)

Bases: OperatorConfig

Configuration for MAGIC-style diffusion imputation.

Attributes:

  • n_neighbors (int): Number of neighbors for local bandwidth estimation.
  • diffusion_t (int): Number of diffusion time steps (matrix power).
  • n_pca_components (int): Number of PCA components (reserved for future use).
  • decay (float): Exponent for the alpha-decaying kernel (MAGIC default is 1).
  • metric (str): Distance metric, either "euclidean" or "cosine".

DifferentiableTransformerDenoiser

diffbio.operators.singlecell.imputation.DifferentiableTransformerDenoiser

DifferentiableTransformerDenoiser(
    config: TransformerDenoiserConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: MaskedGeneTransformerOperatorMixin, OperatorModule

Transformer-based gene denoiser for single-cell expression data.

Genes are treated as tokens in a sequence. For each cell the operator:

  1. Randomly masks a mask_ratio fraction of the genes (sets their expression to 0).
  2. Projects gene IDs into embeddings via TransformerSequenceEncoder (token_embedding mode) and adds a learned projection of the expression value so the transformer receives both identity and magnitude.
  3. Passes the sequence through a transformer encoder to obtain contextualised gene representations.
  4. Predicts masked gene expression from context via a linear output head.
  5. Returns imputed counts where masked positions are replaced with predictions and unmasked positions are kept from the original input.

Each cell is processed independently via jax.vmap over the cell dimension.
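Steps 1 and 5, masking and then merging predictions back in, are plain array operations. In this sketch the mask selects exactly round(mask_ratio * n_genes) genes and is shared across cells, consistent with the (n_genes,) mask listed under Returns; whether the operator uses an exact count or independent Bernoulli draws is an assumption, the helper names are hypothetical, and the transformer is replaced by a constant stand-in.

```python
import jax
import jax.numpy as jnp


def sample_gene_mask(key, n_genes, mask_ratio):
    """Hypothetical helper: boolean (n_genes,) mask with round(mask_ratio * n_genes) True entries."""
    n_masked = int(round(mask_ratio * n_genes))
    chosen = jax.random.permutation(key, n_genes)[:n_masked]
    return jnp.zeros(n_genes, dtype=bool).at[chosen].set(True)


def merge_predictions(counts, predictions, mask):
    """Step 5: predictions at masked genes, original values elsewhere."""
    return jnp.where(mask[None, :], predictions, counts)


counts = jnp.arange(12.0).reshape(3, 4)  # (n_cells, n_genes)
mask = sample_gene_mask(jax.random.key(0), n_genes=4, mask_ratio=0.5)
masked_input = jnp.where(mask[None, :], 0.0, counts)  # step 1: zero the masked genes
predictions = jnp.full_like(counts, -1.0)  # stand-in for the transformer's output head
imputed = merge_predictions(counts, predictions, mask)
```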

Parameters:

  • config (TransformerDenoiserConfig, required): TransformerDenoiserConfig with operator parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

Example

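>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from diffbio.operators.singlecell.imputation import (
...     TransformerDenoiserConfig, DifferentiableTransformerDenoiser)
>>> config = TransformerDenoiserConfig(n_genes=2000, hidden_dim=128)
>>> denoiser = DifferentiableTransformerDenoiser(
...     config, rngs=nnx.Rngs(params=0, sample=1, dropout=2))
>>> rp = denoiser.generate_random_params(
...     jax.random.key(0), {"counts": (100, 2000)})
>>> counts = jnp.ones((100, 2000))  # placeholder counts (n_cells, n_genes)
>>> data = {"counts": counts, "gene_ids": jnp.arange(2000)}
>>> result, state, meta = denoiser.apply(data, {}, None, random_params=rp)
>>> result["imputed_counts"].shape
(100, 2000)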

Parameters:

  • config (TransformerDenoiserConfig, required): Denoiser configuration.
  • rngs (Rngs | None, default None): Random number generators for parameter initialisation.
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply transformer denoising to single-cell count data.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "counts": Gene expression matrix (n_cells, n_genes)
      - "gene_ids": Integer gene IDs (n_genes,)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): JAX random key for mask generation.
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where

  • transformed_data contains:
      - All original keys from data
      - "imputed_counts": Denoised expression (n_cells, n_genes)
      - "mask": Binary mask used (n_genes,)
  • state is passed through unchanged
  • metadata is passed through unchanged

TransformerDenoiserConfig

diffbio.operators.singlecell.imputation.TransformerDenoiserConfig (dataclass)

TransformerDenoiserConfig(
    n_genes: int = 2000,
    hidden_dim: int = 128,
    num_layers: int = 2,
    num_heads: int = 4,
    mask_ratio: float = 0.15,
    dropout_rate: float = 0.1,
)

Bases: MaskedGeneTransformerConfigBase

Configuration for transformer-based gene denoising.

The denoiser treats genes as tokens: each gene has an expression value and a gene ID. A random fraction of genes is masked (expression zeroed) and the transformer predicts the original expression from the unmasked context.

Attributes:

  • n_genes (int, default 2000): Number of genes (tokens per cell).
  • hidden_dim (int, default 128): Hidden dimension of the transformer.
  • num_layers (int, default 2): Number of transformer encoder layers.
  • num_heads (int, default 4): Number of attention heads.
  • mask_ratio (float, default 0.15): Fraction of genes that are masked.
  • dropout_rate (float, default 0.1): Dropout rate.

__post_init__

__post_init__() -> None

Makes masked-gene operators default to stochastic (sampled) execution.

DifferentiableDoubletScorer

diffbio.operators.singlecell.doublet_detection.DifferentiableDoubletScorer

DifferentiableDoubletScorer(
    config: DoubletScorerConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable Scrublet-style doublet detection operator.

Detects doublets by generating synthetic doublet profiles from random cell pairs, embedding real and synthetic cells into PCA space, and scoring each real cell via the Bayesian k-NN likelihood ratio from Scrublet (Wolock et al., 2019).

Algorithm
  1. Generate n_cells * sim_doublet_ratio synthetic doublets
  2. Concatenate real and synthetic cells
  3. PCA-embed via truncated SVD
  4. Compute pairwise distances in PCA space
  5. Adjust k upward: k_adj = round(k * (1 + n_syn / n_cells))
  6. Count soft synthetic neighbors in each real cell's k-NN
  7. Compute Laplace-smoothed fraction q of synthetic neighbors
  8. Bayesian likelihood ratio: Ld = q * rho / r / (1 - rho - q * (1 - rho - rho / r)), where rho = expected_doublet_rate and r = n_syn / n_cells
  9. Apply sigmoid threshold for predicted doublet calls
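Steps 5, 7 and 8 reduce to a few scalar formulas. The sketch below takes the per-cell synthetic-neighbor counts as given (steps 1 to 4 and 6 are omitted) and uses r = n_syn / n_cells, i.e. sim_doublet_ratio; the helper names are hypothetical.

```python
import jax.numpy as jnp


def adjusted_k(k, n_syn, n_cells):
    """Step 5: k_adj = round(k * (1 + n_syn / n_cells))."""
    return int(round(k * (1 + n_syn / n_cells)))


def scrublet_likelihood_ratio(syn_neighbor_count, k_adj, rho, r):
    """Hypothetical helper for steps 7-8: smoothed fraction q and likelihood ratio Ld.

    syn_neighbor_count: (n_cells,) (soft) count of synthetic doublets among the k_adj neighbors.
    rho: expected doublet rate; r: n_syn / n_cells.
    """
    q = (syn_neighbor_count + 1.0) / (k_adj + 2.0)  # Laplace smoothing
    return q * rho / r / (1.0 - rho - q * (1.0 - rho - rho / r))


n_cells, r, rho = 500, 2.0, 0.06
k_adj = adjusted_k(30, n_syn=int(n_cells * r), n_cells=n_cells)  # 90
scores = scrublet_likelihood_ratio(jnp.array([0.0, 30.0, 88.0]), k_adj, rho, r)
```

Under random mixing, where q equals the overall synthetic fraction r / (1 + r), Ld works out to exactly rho; because Ld increases with q, cells with more synthetic neighbors than that score above the prior.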

Parameters:

  • config (DoubletScorerConfig, required): DoubletScorerConfig with operator parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

Example

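>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from diffbio.operators.singlecell.doublet_detection import (
...     DoubletScorerConfig, DifferentiableDoubletScorer)
>>> config = DoubletScorerConfig(n_neighbors=30, n_pca_components=30,
...                              n_genes=2000)
>>> scorer = DifferentiableDoubletScorer(config, rngs=nnx.Rngs(0))
>>> counts = jax.random.poisson(jax.random.key(1), 1.0, (500, 2000)).astype(jnp.float32)
>>> rng = jax.random.key(0)
>>> rp = scorer.generate_random_params(rng, {"counts": (500, 2000)})
>>> result, state, meta = scorer.apply({"counts": counts}, {}, None,
...                                    random_params=rp)
>>> result["doublet_scores"].shape
(500,)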

Parameters:

  • config (DoubletScorerConfig, required): Doublet scorer configuration.
  • rngs (Rngs | None, default None): Random number generators for stochastic operations.
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply doublet detection to single-cell count data.

Implements Scrublet's Bayesian k-NN likelihood-ratio scoring:

  1. Generate n_cells * sim_doublet_ratio synthetic doublets
  2. Adjust k upward: k_adj = round(k * (1 + n_syn / n_cells))
  3. For each real cell, count synthetic neighbors in its k-NN
  4. Compute Laplace-smoothed fraction q = (syn_count + 1) / (k_adj + 2)
  5. Bayesian likelihood ratio: Ld = q * rho / r / (1 - rho - q * (1 - rho - rho / r)), where rho is expected_doublet_rate and r = n_syn / n_cells

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "counts": Gene expression matrix (n_cells, n_genes)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): JAX random key for synthetic doublet generation.
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where

  • transformed_data contains:
      - "counts": Original counts
      - "doublet_scores": Bayesian likelihood ratio per cell
      - "predicted_doublets": Soft doublet predictions in [0, 1]
  • state is passed through unchanged
  • metadata is passed through unchanged

DoubletScorerConfig

diffbio.operators.singlecell.doublet_detection.DoubletScorerConfig (dataclass)

DoubletScorerConfig(
    n_neighbors: int = 30,
    expected_doublet_rate: float = 0.06,
    sim_doublet_ratio: float = 2.0,
    n_pca_components: int = 30,
    n_genes: int = 2000,
    threshold_temperature: float = 10.0,
)

Bases: OperatorConfig

Configuration for Scrublet-style doublet detection.

Attributes:

  • n_neighbors (int): Base number of nearest neighbors for scoring (adjusted upward to account for the size of the synthetic pool).
  • expected_doublet_rate (float): Prior expected fraction of doublets (rho in the Bayesian likelihood ratio).
  • sim_doublet_ratio (float): Ratio of synthetic doublets to real cells. The Scrublet default is 2.0, i.e. twice as many synthetic doublets as real cells.
  • n_pca_components (int): Number of PCA components for embedding.
  • n_genes (int): Number of genes in expression profiles.
  • threshold_temperature (float): Temperature for sigmoid doublet thresholding.

DifferentiableSoloDetector

diffbio.operators.singlecell.doublet_detection.DifferentiableSoloDetector

DifferentiableSoloDetector(
    config: SoloDetectorConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: CountVAEBackboneMixin, EncoderDecoderOperator

Solo-style VAE doublet detector.

Detects doublets by encoding cells through a VAE, generating synthetic doublets in count space, then classifying real vs synthetic cells in the VAE latent space.

Algorithm
  1. Generate synthetic doublets by summing random cell pairs
  2. Concatenate real and synthetic counts
  3. Encode all cells through the VAE encoder to obtain (mean, logvar)
  4. Sample latent z via the reparameterization trick
  5. Run a binary classifier on real-cell latents
  6. Return doublet probabilities, labels, and latent representations
Architecture
  • Encoder: counts -> log1p -> hidden layers (ReLU) -> (mean, logvar)
  • Decoder: z -> hidden layers (ReLU) -> log_rate
  • Classifier: z -> Linear -> ReLU -> Linear -> sigmoid
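Step 1, building synthetic doublets by summing random pairs of real cells in count space, is a gather and an add. The sketch is illustrative: sampling pairs uniformly with replacement (so a cell can occasionally be paired with itself), the function name, and the real/synthetic label vector shown as classifier targets are assumptions, not details confirmed by this reference.

```python
import jax
import jax.numpy as jnp


def simulate_doublets(key, counts, sim_doublet_ratio=2.0):
    """Hypothetical helper: sum random pairs of real cells to create synthetic doublets."""
    n_cells = counts.shape[0]
    n_syn = int(n_cells * sim_doublet_ratio)
    key_a, key_b = jax.random.split(key)
    # Pairs are drawn uniformly with replacement (assumption)
    idx_a = jax.random.randint(key_a, (n_syn,), 0, n_cells)
    idx_b = jax.random.randint(key_b, (n_syn,), 0, n_cells)
    return counts[idx_a] + counts[idx_b], idx_a, idx_b


key = jax.random.key(0)
counts = jnp.arange(20.0).reshape(5, 4)  # (n_cells, n_genes)
synthetic, idx_a, idx_b = simulate_doublets(key, counts)
all_counts = jnp.concatenate([counts, synthetic], axis=0)  # step 2: real + synthetic
# Assumed classifier targets: 0 for real cells, 1 for synthetic doublets
is_doublet = jnp.concatenate([jnp.zeros(counts.shape[0]), jnp.ones(synthetic.shape[0])])
```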

Parameters:

  • config (SoloDetectorConfig, required): SoloDetectorConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

Example

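>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from diffbio.operators.singlecell.doublet_detection import (
...     SoloDetectorConfig, DifferentiableSoloDetector)
>>> config = SoloDetectorConfig(n_genes=2000, latent_dim=10)
>>> detector = DifferentiableSoloDetector(config, rngs=nnx.Rngs(42))
>>> counts = jax.random.poisson(jax.random.key(1), 1.0, (500, 2000)).astype(jnp.float32)
>>> rng = jax.random.key(0)
>>> rp = detector.generate_random_params(rng, {"counts": (500, 2000)})
>>> result, _, _ = detector.apply({"counts": counts}, {}, None, random_params=rp)
>>> result["doublet_probabilities"].shape
(500,)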

Parameters:

  • config (SoloDetectorConfig, required): Solo detector configuration.
  • rngs (Rngs | None, default None): Random number generators for initialization and sampling.
  • name (str | None, default None): Optional operator name.

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply Solo-style VAE doublet detection.

Steps
  1. Generate synthetic doublets from random cell pairs
  2. Concatenate real and synthetic counts
  3. Encode all cells to latent space (mean, logvar)
  4. Sample z via reparameterization trick
  5. Run classifier on real-cell latents
  6. Return probabilities, labels, and latent for real cells only

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "counts": Gene expression matrix (n_cells, n_genes)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

JAX random key for synthetic doublet generation.

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

- transformed_data contains:
  - "counts": Original counts
  - "doublet_probabilities": Per-cell doublet probability
  - "doublet_labels": Soft binary labels (sigmoid-thresholded)
  - "latent": Latent representations for real cells
- state is passed through unchanged
- metadata is passed through unchanged

SoloDetectorConfig

diffbio.operators.singlecell.doublet_detection.SoloDetectorConfig dataclass

SoloDetectorConfig(
    n_genes: int = 2000,
    latent_dim: int = 10,
    hidden_dims: list[int] = (lambda: [128, 64])(),
    classifier_hidden_dim: int = 64,
    sim_doublet_ratio: float = 2.0,
)

Bases: OperatorConfig

Configuration for Solo-style VAE doublet detection.

Solo (Bernstein et al., Cell Systems 2020) trains a VAE on real cells, generates synthetic doublets, encodes both into latent space, and trains a binary classifier to distinguish singlets from doublets.

Attributes:

Name Type Description
n_genes int

Number of genes in expression profiles.

latent_dim int

Dimension of the VAE latent space.

hidden_dims list[int]

Hidden layer dimensions for encoder/decoder.

classifier_hidden_dim int

Hidden dimension for the latent-space classifier.

sim_doublet_ratio float

Ratio of synthetic doublets to real cells.

DifferentiableCellCommunication

diffbio.operators.singlecell.communication.DifferentiableCellCommunication

DifferentiableCellCommunication(
    config: CellCommunicationConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: GraphOperator

GNN-based differentiable cell-cell communication analysis.

Analyses inter-cellular signaling by applying GATv2 graph attention on a spatial cell graph whose edges carry ligand-receptor expression features.

Algorithm
  1. Build per-edge L-R expression features from counts and lr_pairs: for each edge (i, j) the feature vector is the concatenation of [L_expr[source], R_expr[target]] across all L-R pairs, projected to edge_features_dim.
  2. Project per-node gene expression to initial node embeddings.
  3. Apply stacked GATv2 layers (SpatialAttentionGNN) for message passing on the spatial cell graph.
  4. Decode node embeddings into per-node pathway activity and per-node communication scores via SignalingDecoder.
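The edge-feature construction in step 1 can be sketched as follows. The helper name `lr_edge_features` is an illustrative assumption, as is the ordering of the concatenated features (all ligand columns, then all receptor columns). The random linear projection is a stand-in for the operator's learned projection to `edge_features_dim`.

```python
import jax
import jax.numpy as jnp


def lr_edge_features(counts: jax.Array, edge_index: jax.Array, lr_pairs: jax.Array) -> jax.Array:
    """Per-edge L-R features: [L_expr[source], R_expr[target]] across all L-R pairs."""
    source, target = edge_index[0], edge_index[1]
    ligand_expr = counts[:, lr_pairs[:, 0]]    # (n_cells, n_pairs)
    receptor_expr = counts[:, lr_pairs[:, 1]]  # (n_cells, n_pairs)
    # Ligand expression at the sending cell, receptor expression at the receiving cell.
    return jnp.concatenate([ligand_expr[source], receptor_expr[target]], axis=-1)  # (n_edges, 2 * n_pairs)


counts = jnp.arange(12.0).reshape(3, 4)        # 3 cells, 4 genes
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])  # edges 0->1, 1->2, 2->0
lr_pairs = jnp.array([[0, 3], [1, 2]])          # rows are (ligand_gene, receptor_gene)
feats = lr_edge_features(counts, edge_index, lr_pairs)

# Hypothetical stand-in for the learned projection to edge_features_dim.
edge_features_dim = 8
w = 0.1 * jax.random.normal(jax.random.key(0), (feats.shape[-1], edge_features_dim))
projected = feats @ w  # (n_edges, edge_features_dim)
```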

Inherits from GraphOperator to get:

  • scatter_aggregate() for message aggregation
  • global_pool() for graph-level pooling

Parameters:

Name Type Description Default
config CellCommunicationConfig

CellCommunicationConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example

config = CellCommunicationConfig(n_genes=50, n_lr_pairs=3, hidden_dim=32)
op = DifferentiableCellCommunication(config, rngs=nnx.Rngs(0))
data = {"counts": counts, "spatial_graph": graph, "lr_pairs": pairs}
result, state, meta = op.apply(data, {}, None)
result["communication_scores"].shape  # (n_cells, 3)

Parameters:

Name Type Description Default
config CellCommunicationConfig

Cell communication configuration.

required
rngs Rngs | None

Random number generators for parameter initialization.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply GNN-based cell-cell communication analysis.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "counts": Gene expression matrix (n_cells, n_genes)
- "spatial_graph": Edge indices (2, n_edges), where row 0 holds source nodes and row 1 holds target nodes
- "lr_pairs": L-R pair gene indices (n_pairs, 2)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used (non-stochastic operator).

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

- transformed_data contains all original keys plus:
  - "communication_scores": Per-cell communication score for each L-R pair (n_cells, n_pairs)
  - "signaling_activity": Per-cell pathway activity (n_cells, n_pathways)
  - "niche_embeddings": Node embeddings after GATv2 message passing (n_cells, hidden_dim)
- state is passed through unchanged
- metadata is passed through unchanged

CellCommunicationConfig

diffbio.operators.singlecell.communication.CellCommunicationConfig dataclass

CellCommunicationConfig(
    hidden_dim: int = 64,
    num_heads: int = 4,
    num_gnn_layers: int = 2,
    n_pathways: int = 20,
    dropout_rate: float = 0.1,
    n_genes: int = 2000,
    n_lr_pairs: int = 10,
    edge_features_dim: int = 8,
)

Bases: _CellCommunicationInputConfig, _CellCommunicationModelConfig, OperatorConfig

Configuration for GNN-based cell-cell communication analysis.

Attributes:

Name Type Description
hidden_dim int

Dimension of the node embeddings produced by the GATv2 layers (default 64).

num_heads int

Number of GATv2 attention heads (default 4).

num_gnn_layers int

Number of stacked GATv2 layers used for message passing (default 2).

n_pathways int

Number of pathways in the per-node signaling activity output (default 20).

dropout_rate float

Dropout rate (default 0.1).

n_genes int

Number of genes in the expression matrix (default 2000).

n_lr_pairs int

Number of ligand-receptor pairs (default 10).

edge_features_dim int

Dimension to which the per-edge L-R expression features are projected (default 8).

DifferentiableLigandReceptor

diffbio.operators.singlecell.communication.DifferentiableLigandReceptor

DifferentiableLigandReceptor(
    config: LRScoringConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable ligand-receptor co-expression scoring operator.

Scores cell-cell communication by computing adjacency-weighted co-expression of ligand-receptor gene pairs. For each pair, the score at each receiver cell is the sum of sender ligand expression times receiver receptor expression, weighted by a fuzzy k-NN adjacency graph.

Algorithm
  1. Build a symmetric fuzzy k-NN adjacency from the count matrix using compute_pairwise_distances, compute_fuzzy_membership, and symmetrize_graph.
  2. For each L-R pair (ligand_idx, receptor_idx), compute the per-cell score score_i = sum_j(adjacency[i,j] * L[j] * R[i]), where L[j] is the sender's ligand expression and R[i] is the receiver's receptor expression.
  3. Compute analytical soft p-values via z-score comparison against an expected null distribution.
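Because R[i] does not depend on j, the scoring formula above collapses to a matrix product: scores = (adjacency @ L) * R. The sketch below assumes a precomputed dense adjacency and omits both the fuzzy k-NN construction and the soft p-value step. `lr_scores` is a hypothetical name, and a brute-force evaluation of the formula is included for comparison.

```python
import jax.numpy as jnp


def lr_scores(counts: jnp.ndarray, adjacency: jnp.ndarray, lr_pairs: jnp.ndarray) -> jnp.ndarray:
    """score[i, p] = sum_j adjacency[i, j] * L_p[j] * R_p[i] for every L-R pair p."""
    ligand = counts[:, lr_pairs[:, 0]]    # L_p[j]: sender ligand expression, (n_cells, n_pairs)
    receptor = counts[:, lr_pairs[:, 1]]  # R_p[i]: receiver receptor expression, (n_cells, n_pairs)
    # R_p[i] factors out of the sum over j, leaving a single matmul.
    return (adjacency @ ligand) * receptor


counts = jnp.array([[1.0, 2.0, 0.0], [0.0, 1.0, 3.0], [2.0, 0.0, 1.0]])
adjacency = jnp.array([[0.0, 0.5, 1.0], [0.5, 0.0, 0.2], [1.0, 0.2, 0.0]])
lr_pairs = jnp.array([[0, 2], [1, 0]])
scores = lr_scores(counts, adjacency, lr_pairs)  # (n_cells, n_pairs)

# Brute-force evaluation of the same formula, for comparison.
reference = jnp.array([
    [sum(adjacency[i, j] * counts[j, lig] * counts[i, rec] for j in range(3))
     for lig, rec in lr_pairs.tolist()]
    for i in range(3)
])
```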

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection

Parameters:

Name Type Description Default
config LRScoringConfig

LRScoringConfig with operator parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example

config = LRScoringConfig(n_neighbors=15)
op = DifferentiableLigandReceptor(config, rngs=nnx.Rngs(0))
data = {"counts": counts, "lr_pairs": jnp.array([[0, 1]])}
result, state, meta = op.apply(data, {}, None)
result["lr_scores"].shape  # (n_cells, 1)

Parameters:

Name Type Description Default
config LRScoringConfig

L-R scoring configuration.

required
rngs Rngs | None

Random number generators for parameter initialization.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply ligand-receptor co-expression scoring.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "counts": Gene expression matrix (n_cells, n_genes)
- "lr_pairs": L-R pair indices (n_pairs, 2), where each row is [ligand_gene_idx, receptor_gene_idx]

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used (non-stochastic operator).

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

- transformed_data contains all original data keys plus:
  - "lr_scores": Per-cell interaction scores (n_cells, n_pairs)
  - "lr_pvalues": Soft p-values per pair (n_pairs,)
- state is passed through unchanged
- metadata is passed through unchanged

LRScoringConfig

diffbio.operators.singlecell.communication.LRScoringConfig dataclass

LRScoringConfig(
    n_neighbors: int = 15,
    temperature: float = 1.0,
    learnable_temperature: bool = False,
    metric: str = "euclidean",
    kh: float = 0.5,
    hill_n: float = 1.0,
)

Bases: OperatorConfig

Configuration for ligand-receptor co-expression scoring.

Attributes:

Name Type Description
n_neighbors int

Number of nearest neighbors for k-NN graph.

temperature float

Temperature for soft p-value sigmoid.

learnable_temperature bool

Whether the temperature is a learnable parameter.

metric str

Distance metric for k-NN, either "euclidean" or "cosine".

kh float

Hill function half-maximal constant (CellChat default 0.5).

hill_n float

Hill function cooperativity coefficient (CellChat default 1.0).
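For reference, a Hill function parameterised by `kh` and `hill_n` has the standard form x^n / (kh^n + x^n). The defaults above follow CellChat. Which quantity the operator passes through this function is not documented here, so the snippet only evaluates the function itself. `hill` is a hypothetical name.

```python
import jax.numpy as jnp


def hill(x: jnp.ndarray, kh: float = 0.5, hill_n: float = 1.0) -> jnp.ndarray:
    """Hill function x^n / (kh^n + x^n): 0 at x = 0, 0.5 at x = kh, saturating toward 1."""
    xn = jnp.power(x, hill_n)
    return xn / (kh**hill_n + xn)


values = hill(jnp.array([0.0, 0.5, 1.5]))  # defaults kh=0.5, hill_n=1.0
steeper = hill(jnp.array([0.5]), kh=0.5, hill_n=2.0)  # still 0.5 at x = kh
```

Raising `hill_n` leaves the half-maximal point at `kh` but makes the transition around it steeper.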

DifferentiableGRN

diffbio.operators.singlecell.grn_inference.DifferentiableGRN

DifferentiableGRN(
    config: GRNInferenceConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable gene regulatory network inference operator.

Uses GATv2 graph attention on a TF-gene bipartite graph to infer regulatory strengths. Each TF is connected to every gene; the attention weight on each edge represents how strongly the TF regulates that gene.

This is a novel differentiable alternative to GENIE3's random forest feature importance scoring. The key insight is that in GENIE3, each gene's expression is predicted from TF expression, and feature importance measures regulatory strength. Here, GATv2 attention performs an analogous role: TF nodes attend to gene nodes, and the learned attention weights capture regulatory relationships.
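The TF-gene graph described above can be sketched as the edge index of a complete bipartite graph. The node numbering (TFs first, then genes), the TF-major edge order, and the placeholder per-edge scores are all illustrative assumptions. The operator's actual graph layout and edge direction are not specified here. `tf_gene_edge_index` is a hypothetical name.

```python
import jax.numpy as jnp


def tf_gene_edge_index(n_tfs: int, n_genes: int) -> jnp.ndarray:
    """Complete bipartite TF-gene graph as a (2, n_tfs * n_genes) edge index."""
    # Assumed node ids: TFs are 0..n_tfs-1, genes are n_tfs..n_tfs+n_genes-1.
    tf_ids = jnp.repeat(jnp.arange(n_tfs), n_genes)          # 0,0,...,1,1,...
    gene_ids = n_tfs + jnp.tile(jnp.arange(n_genes), n_tfs)  # g0,g1,...,g0,g1,...
    return jnp.stack([tf_ids, gene_ids])


edge_index = tf_gene_edge_index(n_tfs=5, n_genes=20)
# Edges are TF-major, so any per-edge score (e.g. an attention weight)
# reshapes directly into an (n_tfs, n_genes) regulatory matrix.
edge_scores = jnp.arange(edge_index.shape[1], dtype=jnp.float32)  # placeholder scores
grn_matrix = edge_scores.reshape(5, 20)
```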

Parameters:

Name Type Description Default
config GRNInferenceConfig

GRNInferenceConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example

config = GRNInferenceConfig(n_tfs=5, n_genes=20, hidden_dim=16)
op = DifferentiableGRN(config, rngs=nnx.Rngs(0))
data = {"counts": counts, "tf_indices": jnp.arange(5)}
result, state, meta = op.apply(data, {}, None)
result["grn_matrix"].shape  # (5, 20)

Parameters:

Name Type Description Default
config GRNInferenceConfig

GRN inference configuration.

required
rngs Rngs | None

Random number generators for parameter initialization.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply differentiable GRN inference.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "counts": Gene expression matrix (n_cells, n_genes)
- "tf_indices": Indices of TF genes (n_tfs,)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used (non-stochastic operator).

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

- transformed_data contains all original keys plus:
  - "grn_matrix": Sparse regulatory matrix (n_tfs, n_genes)
  - "tf_activity": Per-cell TF activity (n_cells, n_tfs)
- state is passed through unchanged
- metadata is passed through unchanged

GRNInferenceConfig

diffbio.operators.singlecell.grn_inference.GRNInferenceConfig dataclass

GRNInferenceConfig(
    n_tfs: int = 50,
    n_genes: int = 2000,
    hidden_dim: int = 64,
    num_heads: int = 4,
    sparsity_temperature: float = 0.1,
    sparsity_lambda: float = 0.01,
)

Bases: OperatorConfig

Configuration for differentiable GRN inference.

Attributes:

Name Type Description
n_tfs int

Number of transcription factors.

n_genes int

Number of genes in the expression matrix.

hidden_dim int

Hidden dimension for GATv2 attention (must be divisible by num_heads).

num_heads int

Number of attention heads in the GATv2 layer.

sparsity_temperature float

Temperature for soft L1 sparsity gating. Lower values produce sharper thresholding toward zero.

sparsity_lambda float

L1 regularization weight (used by downstream loss functions, not directly by the operator).

DifferentiableMMDBatchCorrection

diffbio.operators.singlecell.enhanced_batch_correction.DifferentiableMMDBatchCorrection

DifferentiableMMDBatchCorrection(
    config: MMDBatchCorrectionConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: LossBalancingMixin, OperatorModule

Autoencoder batch correction with MMD regularisation.

Architecture

Encoder MLP maps gene expression to a latent representation, and a decoder MLP reconstructs the expression from that latent. The MMD loss penalises distributional mismatch between batches in latent space so the learned representation becomes batch-invariant.

Loss

reconstruction_mse + mmd(latent_batch_0, latent_batch_1, ...)
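A minimal sketch of the MMD term with an RBF kernel follows. Several details are assumptions, since the operator's exact estimator and its handling of more than two batches are not documented here: the kernel form exp(-||x - y||² / (2 · bandwidth²)), the biased estimator, and summing the two-sample MMD over every pair of batches. `mmd2` and `multi_batch_mmd` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x: jnp.ndarray, y: jnp.ndarray, bandwidth: float = 1.0) -> jnp.ndarray:
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * bandwidth**2))


def mmd2(x: jnp.ndarray, y: jnp.ndarray, bandwidth: float = 1.0) -> jnp.ndarray:
    """Biased estimate of the squared MMD between samples x and y."""
    k_xx = rbf_kernel(x, x, bandwidth).mean()
    k_yy = rbf_kernel(y, y, bandwidth).mean()
    k_xy = rbf_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy


def multi_batch_mmd(latent, batch_labels, n_batches, bandwidth=1.0):
    # Sum the two-sample MMD over all batch pairs (boolean masking runs eagerly, not under jit).
    total = 0.0
    for a in range(n_batches):
        for b in range(a + 1, n_batches):
            total = total + mmd2(latent[batch_labels == a], latent[batch_labels == b], bandwidth)
    return total


x = jax.random.normal(jax.random.key(0), (64, 8))
y = x + 2.0  # the same cells shifted: a strong "batch effect"
latent = jnp.concatenate([x, y])
labels = jnp.concatenate([jnp.zeros(64, dtype=jnp.int32), jnp.ones(64, dtype=jnp.int32)])
loss = multi_batch_mmd(latent, labels, n_batches=2)
```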

Parameters:

Name Type Description Default
config MMDBatchCorrectionConfig

MMDBatchCorrectionConfig with model hyper-parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example

config = MMDBatchCorrectionConfig(n_genes=2000)
op = DifferentiableMMDBatchCorrection(config, rngs=nnx.Rngs(0))
data = {"expression": expression, "batch_labels": batch_labels}
result, _, _ = op.apply(data, {}, None)

Parameters:

Name Type Description Default
config MMDBatchCorrectionConfig

Operator configuration.

required
rngs Rngs | None

Random number generators for weight initialisation.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Encode, decode, and compute MMD + reconstruction losses.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "expression": Gene expression matrix (n_cells, n_genes)
- "batch_labels": Integer batch assignments (n_cells,)

required
state PyTree

Pipeline state (passed through unchanged).

required
metadata dict[str, Any] | None

Pipeline metadata (passed through unchanged).

required
random_params Any

Unused.

None
stats dict[str, Any] | None

Unused.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (result, state, metadata), where result contains:

- "expression": Original expression
- "batch_labels": Original batch labels
- "corrected_expression": Decoded (corrected) expression
- "latent": Latent representation
- "mmd_loss": Scalar MMD loss between batches
- "reconstruction_loss": Scalar MSE reconstruction loss

MMDBatchCorrectionConfig

diffbio.operators.singlecell.enhanced_batch_correction.MMDBatchCorrectionConfig dataclass

MMDBatchCorrectionConfig(
    n_genes: int = 2000,
    hidden_dim: int = 128,
    latent_dim: int = 64,
    kernel_bandwidth: float = 1.0,
    use_gradnorm: bool = False,
)

Bases: OperatorConfig

Configuration for MMD-based batch correction.

Attributes:

Name Type Description
n_genes int

Number of input genes (features).

hidden_dim int

Width of hidden layers in the autoencoder.

latent_dim int

Dimensionality of the latent space.

kernel_bandwidth float

Bandwidth for the RBF kernel in the MMD loss.

use_gradnorm bool

Whether to use GradNormBalancer for multi-task loss balancing between reconstruction and MMD losses.

DifferentiableWGANBatchCorrection

diffbio.operators.singlecell.enhanced_batch_correction.DifferentiableWGANBatchCorrection

DifferentiableWGANBatchCorrection(
    config: WGANBatchCorrectionConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: LossBalancingMixin, OperatorModule

Adversarial autoencoder batch correction with Wasserstein GAN loss.

Architecture

An encoder (generator) maps gene expression to a batch-invariant latent space, and a decoder reconstructs the expression. A separate discriminator tries to predict the batch label from the latent representation. Gradient reversal (à la scGPT) ensures the encoder learns to fool the discriminator, yielding batch-invariant latents.

Losses
  • generator_loss: Wasserstein generator loss (encoder wants to fool the discriminator) plus reconstruction MSE.
  • discriminator_loss: Wasserstein discriminator/critic loss.

Parameters:

Name Type Description Default
config WGANBatchCorrectionConfig

WGANBatchCorrectionConfig with model hyper-parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example

config = WGANBatchCorrectionConfig(n_genes=2000)
op = DifferentiableWGANBatchCorrection(config, rngs=nnx.Rngs(0))
data = {"expression": expression, "batch_labels": batch_labels}
result, _, _ = op.apply(data, {}, None)

Parameters:

Name Type Description Default
config WGANBatchCorrectionConfig

Operator configuration.

required
rngs Rngs | None

Random number generators for weight initialisation.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Encode, decode, and compute adversarial + reconstruction losses.

The discriminator receives latent codes through a gradient reversal layer so that encoder gradients push toward batch invariance while discriminator gradients push toward better batch classification.
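A gradient reversal layer can be written with `jax.custom_vjp`: it is the identity in the forward pass and negates the gradient in the backward pass. `gradient_reversal`, its `scale` argument, and the toy loss are illustrative. The operator's own layer may differ, for example in how the reversal strength is set.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def gradient_reversal(x: jnp.ndarray, scale: float) -> jnp.ndarray:
    """Identity in the forward pass; multiplies the incoming gradient by -scale."""
    return x


def _grl_fwd(x, scale):
    return x, None  # no residuals are needed for the backward pass


def _grl_bwd(scale, _residuals, g):
    # Flip (and optionally scale) the gradient flowing back toward the encoder.
    return (-scale * g,)


gradient_reversal.defvjp(_grl_fwd, _grl_bwd)


def toy_discriminator_loss(latent):
    # Stand-in for the discriminator loss: any scalar function of the latent codes.
    return jnp.sum(gradient_reversal(latent, 1.0) ** 2)


latent = jnp.array([1.0, -2.0, 3.0])
grad_through_grl = jax.grad(toy_discriminator_loss)(latent)  # equals -2 * latent
```

Placed between the encoder and the discriminator, the layer leaves the discriminator's own parameters with ordinary gradients. Everything upstream of it receives the reversed signal.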

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "expression": Gene expression matrix (n_cells, n_genes)
- "batch_labels": Integer batch assignments (n_cells,)

required
state PyTree

Pipeline state (passed through unchanged).

required
metadata dict[str, Any] | None

Pipeline metadata (passed through unchanged).

required
random_params Any

Unused.

None
stats dict[str, Any] | None

Unused.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (result, state, metadata), where result contains:

- "expression": Original expression
- "batch_labels": Original batch labels
- "corrected_expression": Decoded (corrected) expression
- "latent": Latent representation
- "discriminator_scores": Per-cell critic scores
- "generator_loss": Scalar Wasserstein generator loss
- "discriminator_loss": Scalar Wasserstein discriminator loss

WGANBatchCorrectionConfig

diffbio.operators.singlecell.enhanced_batch_correction.WGANBatchCorrectionConfig dataclass

WGANBatchCorrectionConfig(
    n_genes: int = 2000,
    hidden_dim: int = 128,
    latent_dim: int = 64,
    discriminator_hidden_dim: int = 64,
    use_gradnorm: bool = False,
)

Bases: OperatorConfig

Configuration for WGAN-based batch correction.

Attributes:

Name Type Description
n_genes int

Number of input genes (features).

hidden_dim int

Width of hidden layers in the generator autoencoder.

latent_dim int

Dimensionality of the latent space.

discriminator_hidden_dim int

Width of hidden layers in the discriminator.

use_gradnorm bool

Whether to use GradNormBalancer for multi-task loss balancing between generator and discriminator losses.

DifferentiableSwitchDE

diffbio.operators.singlecell.switch_de.DifferentiableSwitchDE

DifferentiableSwitchDE(
    config: SwitchDEConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable sigmoidal switch model for differential expression.

Models gene expression as a sigmoidal function of pseudotime:

    g(t) = amplitude * sigmoid((t - t_switch) / temperature) + baseline

Each gene has learnable parameters for switch time, amplitude, and baseline expression level. The switch score quantifies how strongly a gene switches, computed as the maximum sigmoid derivative scaled by amplitude.
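The model and the switch score can be written directly from the formula. The derivative of sigmoid((t - t_switch) / T) peaks at t = t_switch with value 1 / (4T), so the score below is |amplitude| / (4T). Two choices here are assumptions: taking the absolute value, so that down-switching genes also score highly, and the function names.

```python
import jax
import jax.numpy as jnp


def switch_model(pseudotime, t_switch, amplitude, baseline, temperature=1.0):
    """g(t) = amplitude * sigmoid((t - t_switch) / temperature) + baseline, per gene."""
    z = (pseudotime[:, None] - t_switch[None, :]) / temperature  # (n_cells, n_genes)
    return amplitude[None, :] * jax.nn.sigmoid(z) + baseline[None, :]


def switch_score(amplitude, temperature=1.0):
    # Maximum slope of the sigmoid term, scaled by |amplitude|.
    return jnp.abs(amplitude) / (4.0 * temperature)


pseudotime = jnp.array([0.0, 0.5, 1.0])
t_switch = jnp.array([0.5, 0.0])
amplitude = jnp.array([2.0, -1.0])
baseline = jnp.array([0.1, 1.0])

predicted = switch_model(pseudotime, t_switch, amplitude, baseline)  # (3, 2)
scores = switch_score(amplitude)  # [0.5, 0.25]

# Check: gene 0's slope at its switch time equals its switch score.
slope = jax.grad(lambda t: switch_model(jnp.array([t]), t_switch, amplitude, baseline)[0, 0])(0.5)
```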

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection

Parameters:

Name Type Description Default
config SwitchDEConfig

SwitchDEConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example
config = SwitchDEConfig(n_genes=2000, temperature=1.0)
op = DifferentiableSwitchDE(config, rngs=nnx.Rngs(42))
data = {"counts": counts, "pseudotime": pseudotime}
result, state, meta = op.apply(data, {}, None)

Parameters:

Name Type Description Default
config SwitchDEConfig

Switch DE configuration.

required
rngs Rngs | None

Random number generators for initialization.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply sigmoidal switch DE model to single-cell data.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "counts": Gene expression counts (n_cells, n_genes)
- "pseudotime": Pseudotime values per cell (n_cells,)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used.

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

- transformed_data contains:
  - "counts": Original expression counts
  - "pseudotime": Original pseudotime
  - "switch_times": Learned switch time per gene
  - "switch_scores": Switch score per gene
  - "predicted_expression": Predicted expression from the model
- state is passed through unchanged
- metadata is passed through unchanged

SwitchDEConfig

diffbio.operators.singlecell.switch_de.SwitchDEConfig dataclass

SwitchDEConfig(
    n_genes: int = 2000,
    temperature: float = 1.0,
    learnable_temperature: bool = False,
)

Bases: OperatorConfig

Configuration for sigmoidal switch differential expression.

Attributes:

Name Type Description
n_genes int

Number of genes to model.

temperature float

Temperature controlling sigmoid smoothness. Lower values produce sharper switch transitions.

learnable_temperature bool

Whether temperature is a learnable parameter.

DifferentiablePseudotime

diffbio.operators.singlecell.trajectory.DifferentiablePseudotime

DifferentiablePseudotime(
    config: PseudotimeConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable pseudotime computation via diffusion maps.

Constructs a k-NN affinity graph, builds a Markov transition matrix, and computes diffusion components through subspace iteration with QR orthogonalization. Pseudotime is defined as the Euclidean distance in diffusion-component space from the designated root cell.

Algorithm
  1. Compute pairwise distances between cells.
  2. Compute fuzzy membership with local bandwidth (k-th neighbor).
  3. Symmetrize the graph via fuzzy set union.
  4. Row-normalize to obtain a Markov transition matrix.
  5. Extract the top n_diffusion_components eigenvectors of the symmetrized transition matrix via subspace iteration (repeated matmul + QR), excluding the trivial eigenvalue 1.
  6. Weight eigenvectors by their Rayleigh-quotient eigenvalues to form diffusion components.
  7. Pseudotime = L2 distance from root cell in diffusion-component space.
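The seven steps above can be sketched compactly with dense matrices. Several choices are assumptions, and the operator's exact formulas may differ:

- The membership kernel is exp(-d_ij / sigma_i), with sigma_i the distance to the k-th neighbour.
- Subspace iteration uses a fixed random start and a fixed iteration count.
- Only the Euclidean metric is handled.

`diffusion_pseudotime` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def diffusion_pseudotime(x, n_neighbors=15, n_components=10, root=0, n_iter=100):
    n = x.shape[0]
    # 1. Pairwise Euclidean distances (clamped so sqrt stays differentiable at zero).
    sq = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    dist = jnp.sqrt(jnp.maximum(sq, 1e-12))
    # 2. Fuzzy membership with a local bandwidth: distance to the k-th neighbour
    #    (index 0 of each sorted row is the cell itself).
    sigma = jnp.sort(dist, axis=1)[:, n_neighbors]
    memb = jnp.exp(-dist / sigma[:, None]) * (1.0 - jnp.eye(n))  # no self-loops
    # 3. Fuzzy set union: a + b - a * b.
    w = memb + memb.T - memb * memb.T
    # 4. Row-normalise to a Markov transition matrix.
    deg = w.sum(axis=1)
    transition = w / deg[:, None]
    # The symmetric conjugate D^-1/2 W D^-1/2 has the same eigenvalues as `transition`.
    d_inv_sqrt = 1.0 / jnp.sqrt(deg)
    sym = w * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    # 5. Subspace iteration (matmul + QR) for the leading n_components + 1 eigenvectors.
    q0, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(0), (n, n_components + 1)))

    def step(_, q):
        q_next, _ = jnp.linalg.qr(sym @ q)
        return q_next

    q = jax.lax.fori_loop(0, n_iter, step, q0)
    # 6. Rayleigh-quotient eigenvalues; sort descending and drop the trivial one (~1).
    eigvals = jnp.sum(q * (sym @ q), axis=0)
    order = jnp.argsort(-eigvals)
    q, eigvals = q[:, order], eigvals[order]
    psi = q * d_inv_sqrt[:, None]  # right eigenvectors of `transition`
    components = psi[:, 1:] * eigvals[None, 1:]
    # 7. Pseudotime = Euclidean distance from the root cell in diffusion space.
    pseudotime = jnp.linalg.norm(components - components[root], axis=1)
    return pseudotime, components, transition


embeddings = jax.random.normal(jax.random.key(1), (50, 20))
pseudotime, components, transition = diffusion_pseudotime(embeddings)
```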

Parameters:

Name Type Description Default
config PseudotimeConfig

PseudotimeConfig with operator parameters.

required
rngs Rngs | None

Flax NNX random number generators (unused, kept for API).

None
name str | None

Optional operator name.

None
Example

config = PseudotimeConfig(n_neighbors=15, n_diffusion_components=10)
op = DifferentiablePseudotime(config)
data = {"embeddings": jnp.ones((50, 20))}
result, state, meta = op.apply(data, {}, None)
result["pseudotime"].shape  # (50,)

Parameters:

Name Type Description Default
config PseudotimeConfig

Pseudotime configuration.

required
rngs Rngs | None

Random number generators (unused, present for API consistency).

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply pseudotime computation to cell embeddings.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "embeddings": Cell embeddings (n_cells, n_features)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used (deterministic operator).

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

- transformed_data contains:
  - "pseudotime": Pseudotime values (n_cells,)
  - "diffusion_components": Diffusion map coordinates (n_cells, n_diffusion_components)
  - "transition_matrix": Markov transition matrix (n_cells, n_cells)
  - all original data keys, preserved
- state is passed through unchanged
- metadata is passed through unchanged

PseudotimeConfig

diffbio.operators.singlecell.trajectory.PseudotimeConfig dataclass

PseudotimeConfig(
    n_neighbors: int = 15,
    n_diffusion_components: int = 10,
    root_cell_index: int = 0,
    metric: str = "euclidean",
)

Bases: OperatorConfig

Configuration for pseudotime computation.

Attributes:

Name Type Description
n_neighbors int

Number of neighbors for k-NN graph construction.

n_diffusion_components int

Number of diffusion map components to retain.

root_cell_index int

Index of the root cell (pseudotime origin).

metric str

Distance metric, "euclidean" or "cosine".

DifferentiableFateProbability

diffbio.operators.singlecell.trajectory.DifferentiableFateProbability

DifferentiableFateProbability(
    config: FateProbabilityConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable fate probability estimation via absorption probabilities.

Given a Markov transition matrix and a set of terminal (absorbing) state indices, partitions cells into transient and absorbing sets and computes the probability that each transient cell will eventually reach each absorbing state.

Algorithm
  1. Partition states into transient (T) and absorbing (A).
  2. Extract sub-matrices Q = transition[T, T] and R = transition[T, A].
  3. Solve (I - Q) @ B = R for B (absorption probabilities).
  4. Give each absorbing state probability 1 of ending in itself (and 0 for every other terminal state).

The linear solve jnp.linalg.solve is fully differentiable.
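The four steps map directly onto a few array operations. In this sketch the state partition uses NumPy, so it runs eagerly and is not jit-compatible with traced `terminal_states`. Gradients with respect to the transition matrix still flow through `jnp.linalg.solve`. `absorption_probabilities` is a hypothetical name.

```python
import numpy as np
import jax.numpy as jnp


def absorption_probabilities(transition, terminal_states):
    """Probability that each cell is eventually absorbed in each terminal state."""
    n = transition.shape[0]
    terminal = np.asarray(terminal_states)
    transient = np.setdiff1d(np.arange(n), terminal)          # 1. partition the states
    q = transition[np.ix_(transient, transient)]              # 2. Q = P[T, T]
    r = transition[np.ix_(transient, terminal)]               #    R = P[T, A]
    b = jnp.linalg.solve(jnp.eye(len(transient)) - q, r)     # 3. (I - Q) B = R
    fate = jnp.zeros((n, len(terminal)), dtype=transition.dtype)
    fate = fate.at[transient].set(b)
    fate = fate.at[terminal, np.arange(len(terminal))].set(1.0)  # 4. absorbing states
    return fate


# Two transient states (0, 1), each of which can step into one of two absorbing states (2, 3).
transition = jnp.array([
    [0.0, 0.5, 0.5, 0.0],
    [0.5, 0.0, 0.0, 0.5],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
])
fate = absorption_probabilities(transition, jnp.array([2, 3]))
macrostates = jnp.argmax(fate, axis=1)
# fate ≈ [[2/3, 1/3], [1/3, 2/3], [1, 0], [0, 1]]; macrostates = [0, 1, 0, 1]
```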

Parameters:

Name Type Description Default
config FateProbabilityConfig

FateProbabilityConfig with operator parameters.

required
rngs Rngs | None

Flax NNX random number generators (unused, kept for API).

None
name str | None

Optional operator name.

None
Example

config = FateProbabilityConfig(n_macrostates=2)
op = DifferentiableFateProbability(config)
data = {"transition_matrix": T, "terminal_states": jnp.array([18, 19])}
result, state, meta = op.apply(data, {}, None)
result["fate_probabilities"].shape  # (20, 2)

Parameters:

Name Type Description Default
config FateProbabilityConfig

Fate probability configuration.

required
rngs Rngs | None

Random number generators (unused, present for API consistency).

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply fate probability estimation.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

- "transition_matrix": Markov transition matrix (n_cells, n_cells)
- "terminal_states": Indices of terminal states (n_terminal,)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used (deterministic operator).

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata):

- transformed_data contains:
  - "fate_probabilities": Absorption probabilities (n_cells, n_terminal)
  - "macrostates": Argmax fate assignment (n_cells,)
  - all original data keys, preserved
- state is passed through unchanged
- metadata is passed through unchanged

FateProbabilityConfig

diffbio.operators.singlecell.trajectory.FateProbabilityConfig dataclass

FateProbabilityConfig(n_macrostates: int = 2)

Bases: OperatorConfig

Configuration for fate probability computation.

Attributes:

Name Type Description
n_macrostates int

Number of macrostates (terminal fates).

DifferentiableSpatialDomain

diffbio.operators.singlecell.spatial_domains.DifferentiableSpatialDomain

DifferentiableSpatialDomain(
    config: SpatialDomainConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: GraphOperator

STAGATE-inspired differentiable spatial domain identification.

Identifies spatial domains by combining gene expression with spatial coordinates through a graph attention autoencoder. The encoder uses dual-graph GATv2 attention (full + pruned k-NN graphs), and soft domain assignments are computed via learned prototypes with softmax.

Algorithm
  1. Build spatial k-NN graph from coordinates (full + pruned/mutual).
  2. Apply GATv2 encoder: counts -> spatial embeddings (dual-graph attention weighted by alpha).
  3. Decoder: reconstruct gene expression from embeddings (autoencoder).
  4. Soft domain assignment via softmax on learned domain prototypes.
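The graph construction, the alpha-weighted combination of the two graphs, and the soft domain assignment can be sketched with dense matrices. Several choices are assumptions:

- Mean aggregation stands in for the GATv2 encoder.
- The two graph outputs are mixed linearly as (1 - alpha) * full + alpha * pruned, which matches the alpha description in SpatialDomainConfig.
- Soft assignments use a softmax over negative squared distances to the prototypes; the operator's exact similarity may differ.

All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def knn_adjacency(coords, k):
    """Directed k-NN adjacency: adj[i, j] = 1 if j is among the k nearest cells to i."""
    n = coords.shape[0]
    sq = jnp.sum((coords[:, None, :] - coords[None, :, :]) ** 2, axis=-1)
    sq = jnp.where(jnp.eye(n, dtype=bool), jnp.inf, sq)  # exclude self-matches
    _, nbrs = jax.lax.top_k(-sq, k)                        # indices of the k smallest distances
    return jnp.zeros((n, n)).at[jnp.arange(n)[:, None], nbrs].set(1.0)


def mean_aggregate(adj, h):
    # Placeholder for a GATv2 layer: average neighbour features (isolated cells get 0).
    deg = jnp.maximum(adj.sum(axis=1, keepdims=True), 1.0)
    return (adj @ h) / deg


coords = jax.random.uniform(jax.random.key(0), (100, 2))
knn = knn_adjacency(coords, k=6)
full_graph = jnp.maximum(knn, knn.T)  # symmetrised full k-NN graph
pruned_graph = knn * knn.T            # mutual k-NN graph

features = jax.random.normal(jax.random.key(1), (100, 16))  # stand-in for projected counts
alpha = 0.8
# Dual-graph mixing: alpha=0 uses only the full graph, alpha=1 only the pruned one.
embeddings = (1 - alpha) * mean_aggregate(full_graph, features) + alpha * mean_aggregate(pruned_graph, features)

# Soft domain assignment against (here random) domain prototypes.
prototypes = jax.random.normal(jax.random.key(2), (7, 16))
sq_dist = jnp.sum((embeddings[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
domain_assignments = jax.nn.softmax(-sq_dist, axis=-1)  # (n_cells, n_domains)
```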

Inherits from GraphOperator to get:

  • scatter_aggregate() for message aggregation
  • global_pool() for graph-level pooling

Parameters:

Name Type Description Default
config SpatialDomainConfig

SpatialDomainConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None

Parameters:

Name Type Description Default
config SpatialDomainConfig

Spatial domain configuration.

required
rngs Rngs | None

Random number generators for parameter initialization.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply spatial domain identification to spatial transcriptomics data.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

  • "counts": Gene expression matrix (n_cells, n_genes)
  • "spatial_coords": Spatial coordinates (n_cells, 2)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used (non-stochastic operator).

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata), where transformed_data contains all original keys plus:

  • "domain_assignments": Soft domain probabilities (n_cells, n_domains)
  • "spatial_embeddings": Latent embeddings (n_cells, hidden_dim)

state and metadata are passed through unchanged.

SpatialDomainConfig

diffbio.operators.singlecell.spatial_domains.SpatialDomainConfig dataclass

SpatialDomainConfig(
    n_genes: int = 2000,
    hidden_dim: int = 64,
    num_heads: int = 4,
    n_domains: int = 7,
    alpha: float = 0.8,
    n_neighbors: int = 15,
)

Bases: OperatorConfig

Configuration for STAGATE-style spatial domain identification.

Attributes:

Name Type Description
n_genes int

Number of input genes.

hidden_dim int

Latent embedding dimension. Must be divisible by num_heads.

num_heads int

Number of GATv2 attention heads.

n_domains int

Number of spatial domains to identify.

alpha float

Weight for pruned graph in dual-graph attention (STAGATE default 0.8). At alpha=0, only the full k-NN graph is used. At alpha=1, only the pruned (mutual k-NN) graph is used.

n_neighbors int

Number of nearest neighbors for spatial k-NN graph.

DifferentiablePASTEAlignment

diffbio.operators.singlecell.spatial_domains.DifferentiablePASTEAlignment

DifferentiablePASTEAlignment(
    config: PASTEAlignmentConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: GraphOperator

PASTE-inspired differentiable spatial transcriptomics slice alignment.

Aligns two spatial transcriptomics slices by computing a fused cost that balances expression dissimilarity with spatial structure (Gromov-Wasserstein) and solving for the optimal transport plan via differentiable Sinkhorn.

Algorithm
  1. Compute expression dissimilarity between slices (Euclidean distance).
  2. Compute intra-slice spatial distance matrices.
  3. Compute Gromov-Wasserstein spatial cost that penalizes distortion of pairwise spatial relationships.
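  4. Fuse costs: (1 - alpha) * expression_cost + alpha * spatial_GW_cost, so that alpha = 0 is pure expression matching, consistent with the PASTEAlignmentConfig.alpha definition.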
  5. Solve OT via SinkhornLayer for the differentiable transport plan.
  6. Align slice 2 coordinates using the transport plan.
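The fused cost in steps 1 to 4 can be sketched for a single, fixed transport plan T. The Gromov-Wasserstein term uses the square-loss expansion G[i, j] = sum_k D1[i, k]^2 p_k + sum_l D2[j, l]^2 q_l - 2 (D1 T D2^T)[i, j], where p and q are the row and column sums of T; this avoids building a four-index tensor. The weighting follows the PASTEAlignmentConfig.alpha description (alpha = 0 is pure expression matching). Function names are hypothetical, and a full fused-GW solver would recompute this cost as the plan changes, which the sketch omits.

```python
import jax
import jax.numpy as jnp


def pairwise_dist(a, b, eps=1e-12):
    """Euclidean distances; eps keeps the sqrt differentiable at zero distance."""
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.sqrt(d2 + eps)


def gw_cost(D1, D2, T):
    """Square-loss GW cost: G[i, j] = sum_{k,l} (D1[i, k] - D2[j, l])**2 * T[k, l]."""
    p, q = T.sum(axis=1), T.sum(axis=0)
    const = (D1 ** 2) @ p[:, None] + ((D2 ** 2) @ q)[None, :]
    return const - 2.0 * D1 @ T @ D2.T


def fused_cost(x1, x2, c1, c2, T, alpha=0.1):
    M = pairwise_dist(x1, x2)  # expression dissimilarity (n1, n2)
    G = gw_cost(pairwise_dist(c1, c1), pairwise_dist(c2, c2), T)
    # alpha = 0: pure expression matching; alpha = 1: pure spatial structure
    return (1.0 - alpha) * M + alpha * G


keys = jax.random.split(jax.random.key(0), 4)
x1 = jax.random.normal(keys[0], (5, 6))   # slice 1 expression
x2 = jax.random.normal(keys[1], (4, 6))   # slice 2 expression
c1 = jax.random.uniform(keys[2], (5, 2))  # slice 1 coordinates
c2 = jax.random.uniform(keys[3], (4, 2))  # slice 2 coordinates
T0 = jnp.full((5, 4), 1.0 / 20)           # uniform initial plan
C = fused_cost(x1, x2, c1, c2, T0)
```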

Inherits from GraphOperator to get:

  • scatter_aggregate() for message aggregation
  • global_pool() for graph-level pooling

Parameters:

Name Type Description Default
config PASTEAlignmentConfig

PASTEAlignmentConfig with alignment parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None

Parameters:

Name Type Description Default
config PASTEAlignmentConfig

PASTE alignment configuration.

required
rngs Rngs | None

Random number generators.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply PASTE-style alignment between two spatial transcriptomics slices.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

  • "slice1_counts": Expression matrix for slice 1 (n1, g)
  • "slice2_counts": Expression matrix for slice 2 (n2, g)
  • "slice1_coords": Spatial coordinates for slice 1 (n1, 2)
  • "slice2_coords": Spatial coordinates for slice 2 (n2, 2)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used (non-stochastic operator).

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata), where transformed_data contains all original keys plus:

  • "transport_plan": OT plan (n1, n2)
  • "aligned_coords": Aligned slice 2 coordinates (n2, 2)

state and metadata are passed through unchanged.

PASTEAlignmentConfig

diffbio.operators.singlecell.spatial_domains.PASTEAlignmentConfig dataclass

PASTEAlignmentConfig(
    alpha: float = 0.1,
    sinkhorn_epsilon: float = 0.1,
    sinkhorn_iters: int = 100,
)

Bases: OperatorConfig

Configuration for PASTE-style spatial transcriptomics slice alignment.

Attributes:

Name Type Description
alpha float

Balance between expression dissimilarity (linear term) and spatial Gromov-Wasserstein cost (quadratic term). 0 = pure expression matching, 1 = pure spatial structure matching. PASTE default: 0.1.

sinkhorn_epsilon float

Entropy regularisation strength for the Sinkhorn optimal transport solver.

sinkhorn_iters int

Number of Sinkhorn iterations.
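sinkhorn_epsilon and sinkhorn_iters correspond to the two knobs of a standard entropic OT solver. A minimal log-domain Sinkhorn sketch follows; it is not diffbio's SinkhornLayer, whose internals are not documented here, and `sinkhorn` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn(cost, a, b, epsilon=0.1, n_iters=100):
    """Log-domain Sinkhorn: returns a plan with row sums ~a and column sums b."""
    log_a, log_b = jnp.log(a), jnp.log(b)

    def step(_, fg):
        f, g = fg
        # Alternately match the row and column marginals with log-sum-exp updates
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    f, g = jax.lax.fori_loop(0, n_iters, step, (jnp.zeros_like(a), jnp.zeros_like(b)))
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


cost = jax.random.uniform(jax.random.key(0), (5, 4))
a = jnp.full((5,), 1.0 / 5)
b = jnp.full((4,), 1.0 / 4)
plan = sinkhorn(cost, a, b, epsilon=0.5, n_iters=200)
```

The demo uses epsilon=0.5 rather than the 0.1 default so that 200 iterations converge comfortably; a smaller epsilon gives a sharper plan but needs more iterations.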

DifferentiableDifferentialDistribution

diffbio.operators.singlecell.differential_distribution.DifferentiableDifferentialDistribution

DifferentiableDifferentialDistribution(
    config: DifferentialDistributionConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable KS-test with learned pattern classification.

For each gene, this operator:

  1. Splits cells into two conditions based on binary condition labels.
  2. Computes a soft empirical CDF using sigmoid smoothing: soft_CDF(x, values) = mean(sigmoid((x - values) / temperature))
  3. Computes a soft KS statistic as the smooth maximum of |CDF_A(x) - CDF_B(x)| over evaluation points, using logsumexp.
  4. Extracts distributional features (mean shift, variance ratio, zero-proportion difference) and passes them through a learned linear head to classify each gene into one of the pattern categories (shift, scale, both, none).
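Steps 1 to 3 can be sketched per gene as below. Assumptions: the evaluation points are the pooled cell values of that gene, conditions are encoded as 0/1, and the split uses weight masks instead of boolean indexing so the function stays jit- and vmap-friendly. The learned pattern head of step 4 is left out, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_cdf(x, values, weights, temperature):
    """Weighted soft ECDF of `values` at points `x`, smoothed with a sigmoid."""
    s = jax.nn.sigmoid((x[:, None] - values[None, :]) / temperature)
    return (s @ weights) / weights.sum()


def soft_ks_gene(values, labels, temperature):
    """Soft KS statistic for one gene; labels are 0/1 condition indicators."""
    w_a = (labels == 0).astype(values.dtype)  # mask for condition A
    w_b = 1.0 - w_a                           # mask for condition B
    gap = jnp.abs(soft_cdf(values, values, w_a, temperature)
                  - soft_cdf(values, values, w_b, temperature))
    return temperature * logsumexp(gap / temperature)  # smooth max over evaluation points


# Map over the gene axis of a (n_cells, n_genes) matrix
soft_ks = jax.vmap(soft_ks_gene, in_axes=(1, None, None))

base = jnp.linspace(0.0, 1.0, 50)
labels = jnp.concatenate([jnp.zeros(50), jnp.ones(50)])
counts = jnp.stack([
    jnp.concatenate([base, base]),        # gene 0: same distribution in both conditions
    jnp.concatenate([base, base + 5.0]),  # gene 1: shifted in condition 1
], axis=1)
ks = soft_ks(counts, labels, 0.1)
```

Gene 0 has identical distributions yet scores about 0.1 * log(100) ≈ 0.46: the logsumexp smooth maximum overshoots the hard maximum by up to temperature * log(n_points), one reason lower temperatures track the true KS statistic more closely.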

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum

Parameters:

Name Type Description Default
config DifferentialDistributionConfig

DifferentialDistributionConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example
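from flax import nnx
from diffbio.operators.singlecell.differential_distribution import DifferentiableDifferentialDistribution, DifferentialDistributionConfig

config = DifferentialDistributionConfig(n_genes=2000, temperature=1.0)
op = DifferentiableDifferentialDistribution(config, rngs=nnx.Rngs(42))
data = {"counts": counts, "condition_labels": labels}  # (n_cells, n_genes), (n_cells,)
result, state, meta = op.apply(data, {}, None)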

Parameters:

Name Type Description Default
config DifferentialDistributionConfig

Differential distribution configuration.

required
rngs Rngs | None

Random number generators for parameter initialisation.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply differentiable differential distribution testing.

For each gene, computes a soft KS statistic and classifies the distributional difference pattern using a learned linear head.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

  • "counts": Gene expression matrix (n_cells, n_genes)
  • "condition_labels": Binary condition labels (n_cells,)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used.

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata), where transformed_data contains:

  • "counts": Original expression counts
  • "condition_labels": Original condition labels
  • "ks_statistics": Soft KS statistic per gene (n_genes,)
  • "pattern_logits": Pattern class logits (n_genes, n_patterns)
  • "pattern_labels": Predicted pattern labels (n_genes,)

state and metadata are passed through unchanged.

DifferentialDistributionConfig

diffbio.operators.singlecell.differential_distribution.DifferentialDistributionConfig dataclass

DifferentialDistributionConfig(
    n_genes: int = 2000,
    temperature: float = 1.0,
    learnable_temperature: bool = False,
    n_pattern_classes: int = 4,
)

Bases: OperatorConfig

Configuration for differentiable differential distribution testing.

Attributes:

Name Type Description
n_genes int

Number of genes to analyse.

temperature float

Temperature controlling sigmoid smoothness in the soft CDF and logsumexp soft max. Lower values yield sharper approximations closer to the true KS statistic.

learnable_temperature bool

Whether temperature is a learnable parameter.

n_pattern_classes int

Number of distributional pattern categories. Default 4 corresponds to (shift, scale, both, none).

DifferentiableSimulator

diffbio.operators.singlecell.simulation.DifferentiableSimulator

DifferentiableSimulator(
    config: SimulationConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Splatter-style differentiable single-cell count simulator.

Generates realistic scRNA-seq count matrices following the Splatter generative model (Zappia et al., 2017), with all steps implemented as differentiable JAX operations.

Algorithm
  1. Gene means: softplus-transformed learnable logits, scaled by Gamma(shape, rate) random perturbation.
  2. Cell library sizes: LogNormal(lib_loc, lib_scale) sampling.
  3. Group assignments: cells divided evenly across groups, with learnable group logits enabling soft assignment.
  4. DE fold-changes: per-group per-gene LogNormal fold-changes, masked by a Bernoulli(de_prob) DE indicator.
  5. Cell means: lib_sizes * gene_means * group_fold_change * batch_effect.
  6. Batch effects: exp(learnable batch_shift) multiplicative scaling.
  7. Dropout: sigmoid-based keep probability as function of log(cell_means).
  8. Counts: cell_means * keep_prob (continuous relaxation of Poisson).
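A stripped-down, non-learnable sketch of steps 1 to 8, using the SimulationConfig defaults for the distribution parameters. Simplifications, all assumptions: no learnable gene or group logits, a hard even split of cells into groups, no batch effects, cell means normalised to per-cell proportions before scaling by library size (the Splatter convention; the description above only says the factors are multiplied), and dropout read as a logistic curve with midpoint dropout_mid and shape dropout_shape. `simulate_counts` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def simulate_counts(key, n_cells=100, n_genes=50, n_groups=2,
                    mean_shape=0.6, mean_rate=0.3, lib_loc=11.0, lib_scale=0.2,
                    de_prob=0.1, de_fac_loc=0.1, de_fac_scale=0.4,
                    dropout_mid=-1.0, dropout_shape=-0.5):
    k_mean, k_lib, k_de, k_fac = jax.random.split(key, 4)
    # 1. Gene means ~ Gamma(shape, rate); jax.random.gamma draws with rate 1
    gene_means = jax.random.gamma(k_mean, mean_shape, (n_genes,)) / mean_rate
    # 2. Library sizes ~ LogNormal(lib_loc, lib_scale)
    lib_sizes = jnp.exp(lib_loc + lib_scale * jax.random.normal(k_lib, (n_cells,)))
    # 3. Even, hard split of cells across groups
    groups = jnp.arange(n_cells) % n_groups
    # 4. LogNormal DE fold-changes where a Bernoulli(de_prob) indicator fires
    de_mask = jax.random.bernoulli(k_de, de_prob, (n_groups, n_genes))
    factors = jnp.exp(de_fac_loc + de_fac_scale * jax.random.normal(k_fac, (n_groups, n_genes)))
    fold = jnp.where(de_mask, factors, 1.0)
    # 5. Cell means, normalised to per-cell proportions (assumption)
    rel = gene_means[None, :] * fold[groups]
    cell_means = lib_sizes[:, None] * rel / rel.sum(axis=1, keepdims=True)
    # 6. Batch effects omitted (single batch)
    # 7. Logistic keep probability in log(mean); a negative shape means
    #    higher-expressed entries drop out less often
    keep_prob = jax.nn.sigmoid(-dropout_shape * (jnp.log(cell_means) - dropout_mid))
    # 8. Continuous relaxation: expected counts instead of Poisson draws
    counts = cell_means * keep_prob
    return {"counts": counts, "group_labels": groups,
            "gene_means": gene_means, "de_mask": de_mask}


out = simulate_counts(jax.random.key(0))
```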

Parameters:

Name Type Description Default
config SimulationConfig

SimulationConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example

import jax
from flax import nnx
from diffbio.operators.singlecell.simulation import DifferentiableSimulator, SimulationConfig

config = SimulationConfig(n_cells=100, n_genes=50, n_groups=2)
sim = DifferentiableSimulator(config, rngs=nnx.Rngs(0, sample=1))
rng = jax.random.key(0)
rp = sim.generate_random_params(rng, {})
result, state, meta = sim.apply({}, {}, None, random_params=rp)
result["counts"].shape  # (100, 50)

Parameters:

Name Type Description Default
config SimulationConfig

Simulation configuration.

required
rngs Rngs | None

Random number generators for parameter initialization.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Simulate a single-cell count matrix.

Follows the Splatter generative model with all steps differentiable: gene means, library sizes, group DE, batch effects, dropout, and Poisson count generation (continuous relaxation).

Parameters:

Name Type Description Default
data PyTree

Input dictionary (may be empty; existing keys are preserved).

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Dictionary of JAX random keys from generate_random_params.

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (output_data, state, metadata), where output_data contains all original data keys plus:

  • "counts": Simulated count matrix (n_cells, n_genes)
  • "group_labels": Hard group assignments (n_cells,)
  • "batch_labels": Batch assignments (n_cells,)
  • "gene_means": Per-gene expression means (n_genes,)
  • "de_mask": Binary DE indicator (n_groups, n_genes)

SimulationConfig

diffbio.operators.singlecell.simulation.SimulationConfig dataclass

SimulationConfig(
    dropout_mid: float = -1.0,
    dropout_shape: float = -0.5,
    mean_shape: float = 0.6,
    mean_rate: float = 0.3,
    lib_loc: float = 11.0,
    lib_scale: float = 0.2,
    de_prob: float = 0.1,
    de_fac_loc: float = 0.1,
    de_fac_scale: float = 0.4,
    n_cells: int = 500,
    n_genes: int = 200,
    n_groups: int = 3,
    n_batches: int = 1,
)

Bases: _SimulationSizeConfig, _SimulationDistributionConfig, _SimulationDropoutConfig, OperatorConfig

Configuration for DifferentiableSimulator.

dropout_mid class-attribute instance-attribute

dropout_mid: float = -1.0

dropout_shape class-attribute instance-attribute

dropout_shape: float = -0.5

mean_shape class-attribute instance-attribute

mean_shape: float = 0.6

mean_rate class-attribute instance-attribute

mean_rate: float = 0.3

lib_loc class-attribute instance-attribute

lib_loc: float = 11.0

lib_scale class-attribute instance-attribute

lib_scale: float = 0.2

de_prob class-attribute instance-attribute

de_prob: float = 0.1

de_fac_loc class-attribute instance-attribute

de_fac_loc: float = 0.1

de_fac_scale class-attribute instance-attribute

de_fac_scale: float = 0.4

n_cells class-attribute instance-attribute

n_cells: int = 500

n_genes class-attribute instance-attribute

n_genes: int = 200

n_groups class-attribute instance-attribute

n_groups: int = 3

n_batches class-attribute instance-attribute

n_batches: int = 1

DifferentiableArchetypalAnalysis

diffbio.operators.singlecell.archetypes.DifferentiableArchetypalAnalysis

DifferentiableArchetypalAnalysis(
    config: ArchetypalAnalysisConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable archetypal analysis with softmax simplex constraints.

Each cell is encoded into archetype weight space via an MLP, then temperature-controlled softmax produces simplex weights. The reconstruction is the convex combination of learnable archetype prototypes, enabling end-to-end gradient-based optimisation.
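The encode, softmax, and convex-combination pipeline can be sketched in plain JAX. The single-hidden-layer ReLU encoder, the initialisation scale, and the mean-squared reconstruction loss are assumptions, as are the helper names; the sketch only shows that the weights stay on the simplex and that gradients reach the archetypes.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_genes, n_archetypes, hidden_dim):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w1": jax.random.normal(k1, (n_genes, hidden_dim)) / jnp.sqrt(n_genes),
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (hidden_dim, n_archetypes)) / jnp.sqrt(hidden_dim),
        "b2": jnp.zeros(n_archetypes),
        "archetypes": jax.random.normal(k3, (n_archetypes, n_genes)),
    }


def archetypal_forward(params, counts, temperature=1.0):
    h = jax.nn.relu(counts @ params["w1"] + params["b1"])     # encoder MLP
    logits = h @ params["w2"] + params["b2"]
    weights = jax.nn.softmax(logits / temperature, axis=-1)   # each row lies on the simplex
    reconstructed = weights @ params["archetypes"]            # convex combination of prototypes
    return weights, reconstructed


def reconstruction_loss(params, counts):
    _, recon = archetypal_forward(params, counts)
    return jnp.mean((counts - recon) ** 2)


key_p, key_x = jax.random.split(jax.random.key(0))
params = init_params(key_p, n_genes=20, n_archetypes=3, hidden_dim=8)
counts = jax.random.uniform(key_x, (10, 20))
weights, recon = archetypal_forward(params, counts)
grads = jax.grad(reconstruction_loss)(params, counts)  # gradients for every parameter, archetypes included
```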

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum

Parameters:

Name Type Description Default
config ArchetypalAnalysisConfig

ArchetypalAnalysisConfig with model parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example
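import jax.numpy as jnp
from flax import nnx
from diffbio.operators.singlecell.archetypes import ArchetypalAnalysisConfig, DifferentiableArchetypalAnalysis

config = ArchetypalAnalysisConfig(n_genes=2000, n_archetypes=5)
op = DifferentiableArchetypalAnalysis(config, rngs=nnx.Rngs(0))
data = {"counts": jnp.ones((100, 2000))}
result, state, meta = op.apply(data, {}, None)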

Parameters:

Name Type Description Default
config ArchetypalAnalysisConfig

Archetypal analysis configuration.

required
rngs Rngs | None

Random number generators for weight initialisation.

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply archetypal analysis to a cell-by-gene count matrix.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

  • "counts": Cell-by-gene matrix (n_cells, n_genes)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used.

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata), where transformed_data contains:

  • "counts": Original count matrix
  • "archetype_weights": Simplex weights (n_cells, n_archetypes)
  • "archetypes": Archetype prototypes (n_archetypes, n_genes)
  • "reconstructed": Reconstructed counts (n_cells, n_genes)

state and metadata are passed through unchanged.

ArchetypalAnalysisConfig

diffbio.operators.singlecell.archetypes.ArchetypalAnalysisConfig dataclass

ArchetypalAnalysisConfig(
    n_genes: int = 2000,
    n_archetypes: int = 5,
    hidden_dim: int = 64,
    temperature: float = 1.0,
    learnable_temperature: bool = False,
)

Bases: OperatorConfig

Configuration for DifferentiableArchetypalAnalysis.

Attributes:

Name Type Description
n_genes int

Number of input genes (features per cell).

n_archetypes int

Number of archetype prototypes to learn.

hidden_dim int

Hidden dimension for the encoder MLP.

temperature float

Softmax temperature (lower = sharper assignments).

learnable_temperature bool

Whether temperature is a learnable parameter.

DifferentiableOTTrajectory

diffbio.operators.singlecell.ot_trajectory.DifferentiableOTTrajectory

DifferentiableOTTrajectory(
    config: OTTrajectoryConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Waddington-OT-style differentiable trajectory inference.

Computes an optimal-transport plan between cell populations at two timepoints, estimates per-cell growth (proliferation) rates, and interpolates an intermediate cell distribution.

Algorithm
  1. Build the squared-Euclidean cost matrix C[i,j] = ||x_i - y_j||^2 between cells at t1 and t2.
  2. Compute the entropy-regularised transport plan via SinkhornLayer.
  3. Derive growth rates from the transport plan: cells whose rows of the plan carry more mass (larger row sums) are assigned higher proliferation. Rates are normalised so that mean(growth_rates) == 1.
  4. Interpolate an intermediate distribution at time s = interpolation_time: x_interp = (1 - s) * x_t1 + s * (T @ x_t2) / T.sum(axis=1, keepdims=True)
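A minimal sketch of steps 1 to 4. One deliberate change: with uniform weights and both marginals enforced exactly, every row of the plan would sum to 1/n1 and the growth rates of step 3 would all be 1, so the sketch relaxes the t1 marginal with a KL penalty (an unbalanced-OT style Sinkhorn update with hypothetical strength `rho`) while keeping the t2 marginal exact. The cost rescaling is also an assumption, growth_rate_regularization is not modelled, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def semi_relaxed_sinkhorn(cost, a, b, epsilon=0.1, n_iters=100, rho=1.0):
    """Sinkhorn with a KL-relaxed source marginal (strength rho) and exact target marginal."""
    tau = rho / (rho + epsilon)  # tau -> 1 recovers balanced Sinkhorn
    log_a, log_b = jnp.log(a), jnp.log(b)

    def step(_, fg):
        f, g = fg
        f = tau * epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    f, g = jax.lax.fori_loop(0, n_iters, step, (jnp.zeros_like(a), jnp.zeros_like(b)))
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


def ot_trajectory(x1, x2, epsilon=0.1, n_iters=100, s=0.5):
    # 1. Squared-Euclidean cost, rescaled for numerical stability (assumption)
    cost = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    cost = cost / cost.mean()
    a = jnp.full((x1.shape[0],), 1.0 / x1.shape[0])
    b = jnp.full((x2.shape[0],), 1.0 / x2.shape[0])
    # 2. Entropy-regularised plan
    plan = semi_relaxed_sinkhorn(cost, a, b, epsilon, n_iters)
    # 3. Growth from row mass, normalised so that mean(growth) == 1
    row_mass = plan.sum(axis=1)
    growth = row_mass / row_mass.mean()
    # 4. Barycentric image of each t1 cell in t2, then linear interpolation at time s
    target = (plan @ x2) / row_mass[:, None]
    interp = (1.0 - s) * x1 + s * target
    return plan, growth, interp


k1, k2 = jax.random.split(jax.random.key(0))
x_t1 = jax.random.normal(k1, (6, 5))
x_t2 = jax.random.normal(k2, (8, 5)) + 1.0
plan, growth, interp = ot_trajectory(x_t1, x_t2)
```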

Parameters:

Name Type Description Default
config OTTrajectoryConfig

OTTrajectoryConfig with operator parameters.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional operator name.

None
Example

import jax.numpy as jnp
from diffbio.operators.singlecell.ot_trajectory import DifferentiableOTTrajectory, OTTrajectoryConfig

config = OTTrajectoryConfig(n_genes=100, sinkhorn_iters=50)
op = DifferentiableOTTrajectory(config)
data = {
    "counts_t1": jnp.ones((20, 100)),
    "counts_t2": jnp.ones((25, 100)),
}
result, state, meta = op.apply(data, {}, None)
result["transport_plan"].shape  # (20, 25)

Parameters:

Name Type Description Default
config OTTrajectoryConfig

OT trajectory configuration.

required
rngs Rngs | None

Random number generators (for API consistency).

None
name str | None

Optional operator name.

None

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply OT-based trajectory inference to two-timepoint expression data.

Parameters:

Name Type Description Default
data PyTree

Dictionary containing:

  • "counts_t1": Expression matrix at timepoint 1 (n1, g)
  • "counts_t2": Expression matrix at timepoint 2 (n2, g)

required
state PyTree

Element state (passed through unchanged).

required
metadata dict[str, Any] | None

Element metadata (passed through unchanged).

required
random_params Any

Not used (non-stochastic operator).

None
stats dict[str, Any] | None

Not used.

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, state, metadata), where transformed_data contains all original keys plus:

  • "transport_plan": OT plan (n1, n2)
  • "growth_rates": Per-cell growth rates (n1,)
  • "interpolated_counts": Interpolated expression at the configured interpolation_time (n1, g)

state and metadata are passed through unchanged.

OTTrajectoryConfig

diffbio.operators.singlecell.ot_trajectory.OTTrajectoryConfig dataclass

OTTrajectoryConfig(
    n_genes: int = 200,
    sinkhorn_epsilon: float = 0.1,
    sinkhorn_iters: int = 100,
    growth_rate_regularization: float = 1.0,
    interpolation_time: float = 0.5,
)

Bases: OperatorConfig

Configuration for OT-based trajectory inference.

Attributes:

Name Type Description
n_genes int

Number of input genes per cell.

sinkhorn_epsilon float

Entropy regularisation strength for the Sinkhorn solver. Larger values produce smoother transport plans.

sinkhorn_iters int

Number of Sinkhorn iterations.

growth_rate_regularization float

Scaling factor applied to raw row-sums before normalisation. Higher values amplify growth-rate variation.

interpolation_time float

Fraction in (0, 1) at which to compute the interpolated cell distribution between t1 and t2.

Usage Examples

Soft K-Means Clustering

from flax import nnx
from diffbio.operators.singlecell import SoftKMeansClustering, SoftClusteringConfig

config = SoftClusteringConfig(n_clusters=10, n_features=50)
clustering = SoftKMeansClustering(config, rngs=nnx.Rngs(42))

data = {"embeddings": embeddings}  # (n_cells, n_features), n_features=50 here
result, _, _ = clustering.apply(data, {}, None)
assignments = result["cluster_assignments"]
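SoftKMeansClustering is described as computing P(k|x) = softmax(-||x - c_k||² / T), with an optional weighted-mean centroid update. A few lines of JAX reproduce that math. This is an illustrative sketch with hypothetical helper names; it is only an assumption that "cluster_assignments" holds exactly this matrix.

```python
import jax
import jax.numpy as jnp


def soft_assign(x, centroids, temperature=1.0):
    """P(k | x) = softmax(-||x - c_k||^2 / T) over clusters k."""
    d2 = jnp.sum((x[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jax.nn.softmax(-d2 / temperature, axis=-1)


def update_centroids(x, resp):
    """Responsibility-weighted mean of the cells assigned to each cluster."""
    return (resp.T @ x) / resp.sum(axis=0)[:, None]


x = jnp.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
centroids = jnp.array([[0.0, 0.0], [5.0, 5.0]])
resp = soft_assign(x, centroids)           # rows sum to 1
new_centroids = update_centroids(x, resp)  # ≈ [[0.05, 0.0], [5.05, 5.0]]
```

Raising the temperature spreads each cell's mass over more clusters; lowering it approaches hard k-means assignments.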

Batch Correction

from flax import nnx
from diffbio.operators.singlecell import DifferentiableHarmony, BatchCorrectionConfig

config = BatchCorrectionConfig(n_clusters=50, n_features=50)
harmony = DifferentiableHarmony(config, rngs=nnx.Rngs(42))

data = {"embeddings": embeddings, "batch_ids": batch_labels}
result, _, _ = harmony.apply(data, {}, None)
corrected = result["corrected_embeddings"]

RNA Velocity

from flax import nnx
from diffbio.operators.singlecell import DifferentiableVelocity, VelocityConfig

config = VelocityConfig(n_genes=2000, hidden_dim=64)
velocity = DifferentiableVelocity(config, rngs=nnx.Rngs(42))

data = {"spliced": spliced, "unspliced": unspliced}
result, _, _ = velocity.apply(data, {}, None)
vel = result["velocity"]