Single-Cell Operators API
Differentiable operators for single-cell analysis including clustering, batch correction, and RNA velocity.
SoftKMeansClustering
diffbio.operators.singlecell.soft_clustering.SoftKMeansClustering
SoftKMeansClustering(
config: SoftClusteringConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable soft k-means clustering.
This operator implements soft k-means with learnable cluster centroids. Instead of hard cluster assignments, cells are softly assigned to clusters using softmax over negative squared distances.
Algorithm:
1. Compute squared distances from cells to centroids.
2. Apply softmax for soft assignments: P(k|x) = softmax(-||x - c_k||² / T).
3. Optionally update centroids based on weighted means.
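The soft assignment and the optional centroid update can be sketched in JAX as follows. This is a minimal illustration, not the operator's actual implementation: the function names `soft_assign` and `update_centroids` and the `eps` stabiliser are hypothetical.

```python
import jax
import jax.numpy as jnp


def soft_assign(embeddings, centroids, temperature=1.0):
    """Soft assignments P(k|x) = softmax(-||x - c_k||^2 / T)."""
    # Pairwise squared distances, shape (n_cells, n_clusters).
    diffs = embeddings[:, None, :] - centroids[None, :, :]
    sq_dists = jnp.sum(diffs**2, axis=-1)
    return jax.nn.softmax(-sq_dists / temperature, axis=-1)


def update_centroids(embeddings, assignments, eps=1e-8):
    """Optional step: centroids as assignment-weighted means (eps is an assumption)."""
    weights = assignments / (assignments.sum(axis=0, keepdims=True) + eps)
    return weights.T @ embeddings


embeddings = jnp.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
centroids = jnp.array([[0.0, 0.0], [5.0, 5.0]])
assignments = soft_assign(embeddings, centroids, temperature=1.0)
new_centroids = update_centroids(embeddings, assignments)
```

Lowering `temperature` pushes each row of `assignments` toward a one-hot vector, while raising it spreads probability across clusters.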
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `SoftClusteringConfig` | SoftClusteringConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |

Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `SoftClusteringConfig` | Clustering configuration. | required |
| rngs | `Rngs \| None` | Random number generators for initialization. | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply soft k-means clustering to cell embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: - "embeddings": Cell embeddings (n_cells, n_features) | required |
| state | `PyTree` | Element state (passed through unchanged) | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged) | required |
| random_params | `Any` | Not used | None |
| stats | `dict[str, Any] \| None` | Not used | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
SoftClusteringConfig
diffbio.operators.singlecell.soft_clustering.SoftClusteringConfig
dataclass
SoftClusteringConfig(
n_clusters: int = 10,
n_features: int = 50,
temperature: float = 1.0,
learnable_centroids: bool = True,
)
Bases: OperatorConfig
Configuration for SoftKMeansClustering.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_clusters | `int` | Number of clusters. |
| n_features | `int` | Dimensionality of input embeddings. |
| temperature | `float` | Temperature for softmax (lower = sharper). |
| learnable_centroids | `bool` | Whether centroids are learnable parameters. |
DifferentiableHarmony
diffbio.operators.singlecell.batch_correction.DifferentiableHarmony
DifferentiableHarmony(
config: BatchCorrectionConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable Harmony-style batch correction.
This operator implements iterative batch correction using soft clustering with batch-aware updates. The fixed number of iterations enables gradient flow through the entire correction process.
Algorithm:
1. Initialize cluster centroids from data.
2. Soft assignment of cells to clusters.
3. Compute batch-aware centroid corrections.
4. Update cell embeddings toward corrected centroids.
5. Repeat for n_iterations.
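The iteration structure can be illustrated with a deliberately simplified JAX sketch. It is not the operator's implementation and it ignores `theta` and `sigma`. The helper name `harmony_like_step`, the per-batch centroid offset used as the "correction", and the `eps` stabiliser are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def harmony_like_step(embeddings, batch_labels, centroids, n_batches,
                      temperature=1.0, eps=1e-8):
    """One simplified correction step (hypothetical helper, not the library API)."""
    # Soft assignment of cells to clusters, shape (n_cells, n_clusters).
    diffs = embeddings[:, None, :] - centroids[None, :, :]
    assign = jax.nn.softmax(-jnp.sum(diffs**2, axis=-1) / temperature, axis=-1)
    batch_onehot = jax.nn.one_hot(batch_labels, n_batches)

    # Batch-agnostic cluster centroids, shape (n_clusters, n_features).
    centroids = (assign.T @ embeddings) / (assign.sum(axis=0)[:, None] + eps)

    # Per-(cluster, batch) centroids and their offsets from the cluster centroid.
    joint = assign[:, :, None] * batch_onehot[:, None, :]  # (n_cells, K, B)
    joint_mass = joint.sum(axis=0)[..., None]               # (K, B, 1)
    batch_centroids = jnp.einsum("nkb,nd->kbd", joint, embeddings) / (joint_mass + eps)
    offsets = batch_centroids - centroids[:, None, :]

    # Remove each cell's assignment-weighted batch offset.
    corrected = embeddings - jnp.einsum("nkb,kbd->nd", joint, offsets)
    return corrected, centroids


embeddings = jnp.array([[0.0], [1.0], [10.0], [11.0]])
batch_labels = jnp.array([0, 0, 1, 1])
centroids = jnp.array([[5.5]])

corrected = embeddings
for _ in range(3):  # a fixed iteration count keeps the loop differentiable
    corrected, centroids = harmony_like_step(corrected, batch_labels, centroids, n_batches=2)
```

In this toy case with a single cluster, the two batches end up with the same mean after the first step, and later steps leave the embeddings unchanged.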
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `BatchCorrectionConfig` | BatchCorrectionConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |

Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `BatchCorrectionConfig` | Batch correction configuration. | required |
| rngs | `Rngs \| None` | Random number generators for initialization. | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply batch correction to cell embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: - "embeddings": Cell embeddings (n_cells, n_features) - "batch_labels": Batch assignments (n_cells,) | required |
| state | `PyTree` | Element state (passed through unchanged) | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged) | required |
| random_params | `Any` | Not used | None |
| stats | `dict[str, Any] \| None` | Not used | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
BatchCorrectionConfig
diffbio.operators.singlecell.batch_correction.BatchCorrectionConfig
dataclass
BatchCorrectionConfig(
n_clusters: int = 100,
n_features: int = 50,
n_batches: int = 2,
n_iterations: int = 10,
theta: float = 2.0,
sigma: float = 0.1,
temperature: float = 1.0,
)
Bases: OperatorConfig
Configuration for DifferentiableHarmony.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_clusters | `int` | Number of clusters for soft assignment. |
| n_features | `int` | Dimensionality of input embeddings. |
| n_batches | `int` | Number of distinct batches. |
| n_iterations | `int` | Number of correction iterations. |
| theta | `float` | Diversity penalty parameter. |
| sigma | `float` | Soft assignment bandwidth. |
| temperature | `float` | Temperature for softmax operations. |
DifferentiableVelocity
diffbio.operators.singlecell.velocity.DifferentiableVelocity
DifferentiableVelocity(
config: VelocityConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Differentiable RNA velocity estimation via Neural ODEs.
This operator estimates RNA velocity from spliced and unspliced counts using learned kinetics parameters and differentiable ODE integration.
Algorithm:
1. Encode expression to latent time per cell.
2. Learn per-gene kinetics (alpha, beta, gamma).
3. Compute velocity from the splicing ODE:
   ds/dt = beta * u - gamma * s
   du/dt = alpha - beta * u
4. Integrate the ODE using the Euler method.
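The splicing ODE and its Euler integration can be written compactly in JAX. The sketch below assumes fixed per-gene kinetics and omits the latent-time encoder. The function names `splicing_velocity` and `euler_integrate` are hypothetical, and `dt` and `n_steps` mirror the `VelocityConfig` defaults.

```python
import jax.numpy as jnp
from jax import lax


def splicing_velocity(u, s, alpha, beta, gamma):
    """Right-hand side of the splicing ODE for unspliced (u) and spliced (s) counts."""
    du_dt = alpha - beta * u
    ds_dt = beta * u - gamma * s
    return du_dt, ds_dt


def euler_integrate(u0, s0, alpha, beta, gamma, dt=0.1, n_steps=10):
    """Fixed-step Euler integration, differentiable through lax.scan."""

    def step(carry, _):
        u, s = carry
        du_dt, ds_dt = splicing_velocity(u, s, alpha, beta, gamma)
        return (u + dt * du_dt, s + dt * ds_dt), None

    (u, s), _ = lax.scan(step, (u0, s0), None, length=n_steps)
    return u, s


# Per-gene kinetics for two genes, broadcast over three cells.
alpha = jnp.array([2.0, 1.0])
beta = jnp.array([1.0, 0.5])
gamma = jnp.array([0.5, 1.0])

u_one, s_one = euler_integrate(jnp.zeros((3, 2)), jnp.zeros((3, 2)), alpha, beta, gamma, n_steps=1)

# At the steady state u* = alpha / beta, s* = alpha / gamma both derivatives vanish.
u_star = jnp.broadcast_to(alpha / beta, (3, 2))
s_star = jnp.broadcast_to(alpha / gamma, (3, 2))
u_end, s_end = euler_integrate(u_star, s_star, alpha, beta, gamma)
```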
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `VelocityConfig` | VelocityConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |

Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `VelocityConfig` | Velocity configuration. | required |
| rngs | `Rngs \| None` | Random number generators for initialization. | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply RNA velocity estimation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: - "spliced": Spliced mRNA counts (n_cells, n_genes) - "unspliced": Unspliced mRNA counts (n_cells, n_genes) | required |
| state | `PyTree` | Element state (passed through unchanged) | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged) | required |
| random_params | `Any` | Not used | None |
| stats | `dict[str, Any] \| None` | Not used | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
VelocityConfig
diffbio.operators.singlecell.velocity.VelocityConfig
dataclass
VelocityConfig(
n_genes: int = 2000,
hidden_dim: int = 64,
dt: float = 0.1,
n_steps: int = 10,
kinetics_model: str = "standard",
)
Bases: OperatorConfig
Configuration for DifferentiableVelocity.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_genes | `int` | Number of genes. |
| hidden_dim | `int` | Hidden dimension for neural networks. |
| dt | `float` | Time step for ODE integration. |
| n_steps | `int` | Number of integration steps. |
| kinetics_model | `str` | Type of kinetics model ("standard" or "dynamical"). |
DifferentiableAmbientRemoval
diffbio.operators.singlecell.ambient_removal.DifferentiableAmbientRemoval
DifferentiableAmbientRemoval(
config: AmbientRemovalConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: EncoderDecoderOperator
Differentiable ambient RNA removal using VAE.
This operator removes ambient RNA contamination from single-cell count data using a variational autoencoder that models both cell-intrinsic expression and ambient contamination.
Algorithm:
1. Encode counts to latent space + contamination fraction.
2. Sample latent (reparameterization trick).
3. Decode to cell-intrinsic expression rate.
4. Compute decontaminated counts by subtracting ambient contribution.
Inherits from EncoderDecoderOperator to get:
- reparameterize() for sampling with reparameterization trick
- kl_divergence() for KL from standard normal
- elbo_loss() for combining reconstruction and KL losses
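The sampling, KL, and subtraction steps can be sketched as standalone JAX functions. These are illustrative stand-ins, not the methods inherited from EncoderDecoderOperator. In particular, the way the ambient contribution is scaled in `remove_ambient` (contamination fraction times library size times ambient profile) is an assumption, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, logvar):
    """Sample z = mean + sigma * eps so gradients flow through mean and logvar."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps


def kl_to_standard_normal(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)), summed over latent dimensions."""
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)


def remove_ambient(counts, ambient_profile, contamination):
    """Subtract an assumed ambient contribution and clip at zero."""
    library_size = counts.sum(axis=-1, keepdims=True)
    ambient_counts = contamination[:, None] * library_size * ambient_profile[None, :]
    return jnp.maximum(counts - ambient_counts, 0.0)


mean = jnp.zeros((2, 3))
logvar = jnp.zeros((2, 3))
z = reparameterize(jax.random.key(0), mean, logvar)
kl = kl_to_standard_normal(mean, logvar)

counts = jnp.array([[10.0, 0.0]])
ambient_profile = jnp.array([0.5, 0.5])
clean = remove_ambient(counts, ambient_profile, contamination=jnp.array([0.2]))
```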
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `AmbientRemovalConfig` | AmbientRemovalConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |

Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `AmbientRemovalConfig` | Ambient removal configuration. | required |
| rngs | `Rngs \| None` | Random number generators for initialization. | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply ambient RNA removal.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: - "counts": Raw count matrix (n_cells, n_genes) - "ambient_profile": Ambient expression profile (n_genes,) | required |
| state | `PyTree` | Element state (passed through unchanged) | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged) | required |
| random_params | `Any` | Random key for stochastic sampling | None |
| stats | `dict[str, Any] \| None` | Not used | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
AmbientRemovalConfig
diffbio.operators.singlecell.ambient_removal.AmbientRemovalConfig
dataclass
AmbientRemovalConfig(
n_genes: int = 2000,
latent_dim: int = 64,
hidden_dims: list[int] = (lambda: [256, 128])(),
ambient_prior: float = 0.01,
temperature: float = 1.0,
)
Bases: OperatorConfig
Configuration for DifferentiableAmbientRemoval.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_genes | `int` | Number of genes in expression profiles. |
| latent_dim | `int` | Dimension of latent space. |
| hidden_dims | `list[int]` | Hidden layer dimensions for encoder/decoder. |
| ambient_prior | `float` | Prior probability of ambient contamination. |
| temperature | `float` | Temperature for softmax operations. |
DifferentiableCellAnnotator
diffbio.operators.singlecell.cell_annotation.DifferentiableCellAnnotator
DifferentiableCellAnnotator(
config: CellAnnotatorConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: CountReconstructionMixin, CountVAEBackboneMixin, EncoderDecoderOperator
Differentiable cell type annotator with three annotation modes.
Modes
- celltypist (logistic regression on latent): Encode counts to a VAE latent, apply a linear classifier head, softmax.
- cellassign (marker-gene likelihood): Given a binary marker matrix M, compute per-type Poisson log-likelihoods with learnable rate parameters, then softmax.
- scanvi (semi-supervised VAE with type-conditioned prior): VAE encoder + classifier head with learnable per-type Gaussian priors in latent space. The KL divergence uses p(z|y) = N(mu_y, sigma_y) instead of the standard N(0, I), and is marginalised over predicted type probabilities for unlabelled cells.
All modes additionally produce a latent representation via a shared VAE encoder.
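For the scanvi mode, the type-conditioned KL term can be illustrated with the closed-form KL between diagonal Gaussians, averaged over predicted type probabilities. The sketch below is not the operator's code. It assumes priors are parameterised by a mean and a log-variance per type, and the function names `kl_diag_gaussians` and `marginal_kl` are hypothetical.

```python
import jax.numpy as jnp


def kl_diag_gaussians(mean_q, logvar_q, mean_p, logvar_p):
    """KL(N(mean_q, var_q) || N(mean_p, var_p)) for diagonal covariances."""
    return 0.5 * jnp.sum(
        logvar_p - logvar_q
        + (jnp.exp(logvar_q) + (mean_q - mean_p) ** 2) / jnp.exp(logvar_p)
        - 1.0,
        axis=-1,
    )


def marginal_kl(mean_q, logvar_q, prior_means, prior_logvars, type_probs):
    """Expected KL under predicted type probabilities, one value per cell."""
    # Broadcast cells against types: result has shape (n_cells, n_types).
    kl_per_type = kl_diag_gaussians(
        mean_q[:, None, :], logvar_q[:, None, :],
        prior_means[None, :, :], prior_logvars[None, :, :],
    )
    return jnp.sum(type_probs * kl_per_type, axis=-1)


# One cell, two types, two latent dimensions.
mean_q = jnp.array([[1.0, 0.0]])
logvar_q = jnp.zeros((1, 2))
prior_means = jnp.array([[1.0, 0.0], [0.0, 0.0]])
prior_logvars = jnp.zeros((2, 2))
type_probs = jnp.array([[0.25, 0.75]])

kl = marginal_kl(mean_q, logvar_q, prior_means, prior_logvars, type_probs)
```

Here the posterior matches the first type's prior exactly (KL of 0) and sits one unit from the second (KL of 0.5), so the marginal value is 0.75 * 0.5 = 0.375.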
Inherits from EncoderDecoderOperator to get:
- reparameterize() for the VAE sampling step
- kl_divergence() for KL from standard normal
- elbo_loss() for combining reconstruction and KL losses
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `CellAnnotatorConfig` | CellAnnotatorConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |

Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `CellAnnotatorConfig` | Annotator configuration. | required |
| rngs | `Rngs \| None` | Random number generators for initialisation and sampling. | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Annotate cells with type probabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: | required |
| state | `PyTree` | Element state (passed through unchanged). | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| random_params | `Any` | Not used. | None |
| stats | `dict[str, Any] \| None` | Not used. | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata) where transformed_data adds: |
CellAnnotatorConfig
diffbio.operators.singlecell.cell_annotation.CellAnnotatorConfig
dataclass
CellAnnotatorConfig(
annotation_mode: Literal[
"scanvi", "cellassign", "celltypist"
] = "celltypist",
n_cell_types: int = 10,
n_genes: int = 2000,
latent_dim: int = 10,
hidden_dims: list[int] = (lambda: [128, 64])(),
marker_matrix_shape: tuple[int, int] | None = None,
gene_likelihood: Literal["poisson", "zinb"] = "poisson",
)
Bases: OperatorConfig
Configuration for cell type annotation.
Attributes:
| Name | Type | Description |
|---|---|---|
| annotation_mode | `Literal['scanvi', 'cellassign', 'celltypist']` | Annotation strategy to use. |
| n_cell_types | `int` | Number of cell types to classify. |
| n_genes | `int` | Number of input genes. |
| latent_dim | `int` | Latent-space dimensionality for VAE encoder. |
| hidden_dims | `list[int]` | Hidden layer sizes for encoder and decoder. |
| marker_matrix_shape | `tuple[int, int] \| None` | Shape (n_types, n_genes) for cellassign mode. |
| gene_likelihood | `Literal['poisson', 'zinb']` | Reconstruction likelihood for scanvi mode. |
DifferentiableDiffusionImputer
diffbio.operators.singlecell.imputation.DifferentiableDiffusionImputer
DifferentiableDiffusionImputer(
config: DiffusionImputerConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Differentiable MAGIC-style diffusion imputation.
Constructs a cell-cell affinity graph using an alpha-decaying kernel,
symmetrizes it, builds a row-stochastic Markov matrix M = D^{-1} A,
and computes M^t via repeated matrix multiplication for imputation.
This avoids eigendecomposition (whose backward pass produces NaN when
eigenvalues are near-degenerate) while remaining fully differentiable.
Algorithm
- Compute pairwise distances between cells
- Build alpha-decay affinity: K(i,j) = exp(-(d/sigma_i)^decay)
- Symmetrize the affinity via fuzzy set union
- Row-normalize to Markov matrix M = D^{-1} A
- Compute M^t via repeated matrix multiplication (t iterations)
- Impute: imputed = M^t @ counts
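The steps above can be sketched in plain JAX. This is an illustration under stated assumptions rather than the operator's code: `sigma_i` is taken to be the distance to the `n_neighbors`-th nearest neighbor, the fuzzy set union is taken to be `A + A^T - A * A^T`, and `eps` is an added stabiliser. The function name `diffusion_impute` is hypothetical.

```python
import jax.numpy as jnp


def diffusion_impute(counts, n_neighbors=5, diffusion_t=3, decay=1.0, eps=1e-8):
    """MAGIC-style imputation sketch; requires n_cells > n_neighbors."""
    # Pairwise Euclidean distances; eps keeps sqrt differentiable at zero.
    sq = jnp.sum(counts**2, axis=1)
    sq_dists = jnp.maximum(sq[:, None] + sq[None, :] - 2.0 * counts @ counts.T, 0.0)
    dists = jnp.sqrt(sq_dists + eps)

    # Local bandwidth: distance to the k-th nearest neighbor (index 0 is the cell itself).
    sigma = jnp.sort(dists, axis=1)[:, n_neighbors]

    # Alpha-decay affinity K(i, j) = exp(-(d / sigma_i)^decay).
    affinity = jnp.exp(-((dists / (sigma[:, None] + eps)) ** decay))

    # Symmetrize via fuzzy set union (assumed form).
    affinity = affinity + affinity.T - affinity * affinity.T

    # Row-normalize to a Markov matrix M = D^{-1} A.
    markov = affinity / affinity.sum(axis=1, keepdims=True)

    # M^t by repeated multiplication, avoiding eigendecomposition.
    powered = markov
    for _ in range(diffusion_t - 1):
        powered = powered @ markov
    return powered @ counts


flat = diffusion_impute(jnp.ones((10, 3)))
counts = jnp.arange(40.0).reshape(10, 4)
imputed = diffusion_impute(counts)
```

Because each row of `M^t` is a probability distribution over cells, every imputed value is a convex combination of the observed values for that gene.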
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `DiffusionImputerConfig` | DiffusionImputerConfig with operator parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators (not used, kept for API). | None |
| name | `str \| None` | Optional operator name. | None |
Example
    >>> config = DiffusionImputerConfig(n_neighbors=5, diffusion_t=3)
    >>> imputer = DifferentiableDiffusionImputer(config, rngs=nnx.Rngs(0))
    >>> data = {"counts": jnp.ones((100, 2000))}
    >>> result, state, meta = imputer.apply(data, {}, None)
    >>> result["imputed_counts"].shape
    (100, 2000)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `DiffusionImputerConfig` | Imputation configuration. | required |
| rngs | `Rngs \| None` | Random number generators (unused, present for API consistency). | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply diffusion imputation to single-cell count data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: | required |
| state | `PyTree` | Element state (passed through unchanged). | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| random_params | `Any` | Not used (deterministic operator). | None |
| stats | `dict[str, Any] \| None` | Not used. | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
DiffusionImputerConfig
diffbio.operators.singlecell.imputation.DiffusionImputerConfig
dataclass
DiffusionImputerConfig(
n_neighbors: int = 5,
diffusion_t: int = 3,
n_pca_components: int = 100,
decay: float = 1.0,
metric: str = "euclidean",
)
Bases: OperatorConfig
Configuration for MAGIC-style diffusion imputation.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_neighbors | `int` | Number of neighbors for local bandwidth estimation. |
| diffusion_t | `int` | Number of diffusion time steps (matrix power). |
| n_pca_components | `int` | Number of PCA components (reserved for future use). |
| decay | `float` | Exponent for the alpha-decaying kernel (MAGIC default is 1). |
| metric | `str` | Distance metric, either |
DifferentiableTransformerDenoiser
diffbio.operators.singlecell.imputation.DifferentiableTransformerDenoiser
DifferentiableTransformerDenoiser(
config: TransformerDenoiserConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: MaskedGeneTransformerOperatorMixin, OperatorModule
Transformer-based gene denoiser for single-cell expression data.
Genes are treated as tokens in a sequence. For each cell the operator:
- Randomly masks a mask_ratio fraction of genes (sets expression to 0).
- Projects gene IDs into embeddings via TransformerSequenceEncoder (token_embedding mode) and adds a learned projection of the expression value so the transformer receives both identity and magnitude.
- Passes the sequence through a transformer encoder to obtain contextualised gene representations.
- Predicts masked gene expression from context via a linear output head.
- Returns imputed counts where masked positions are replaced with predictions and unmasked positions are kept from the original input.
Each cell is processed independently via jax.vmap over the cell
dimension.
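The masking and merge steps around the transformer can be sketched independently of the model itself. The example below assumes an independent Bernoulli mask per gene, which matches `mask_ratio` only in expectation, and uses a constant placeholder in place of real transformer predictions. The function names `mask_genes` and `merge_predictions` are hypothetical.

```python
import jax
import jax.numpy as jnp


def mask_genes(key, expression, mask_ratio=0.15):
    """Zero out a random subset of genes for one cell; returns (masked, mask)."""
    mask = jax.random.bernoulli(key, mask_ratio, expression.shape)
    return jnp.where(mask, 0.0, expression), mask


def merge_predictions(expression, predictions, mask):
    """Use predictions at masked positions and keep the original values elsewhere."""
    return jnp.where(mask, predictions, expression)


counts = jnp.arange(1.0, 13.0).reshape(3, 4)  # 3 cells, 4 genes

# One key per cell, then vmap over the cell dimension.
keys = jax.random.split(jax.random.key(0), counts.shape[0])
masked, mask = jax.vmap(mask_genes, in_axes=(0, 0, None))(keys, counts, 0.5)

predictions = jnp.full_like(counts, -1.0)  # placeholder for transformer output
imputed = merge_predictions(counts, predictions, mask)
```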
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `TransformerDenoiserConfig` | TransformerDenoiserConfig with operator parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |
Example
    >>> config = TransformerDenoiserConfig(n_genes=2000, hidden_dim=128)
    >>> denoiser = DifferentiableTransformerDenoiser(
    ...     config, rngs=nnx.Rngs(params=0, sample=1, dropout=2))
    >>> rp = denoiser.generate_random_params(
    ...     jax.random.key(0), {"counts": (100, 2000)})
    >>> data = {"counts": counts, "gene_ids": jnp.arange(2000)}
    >>> result, state, meta = denoiser.apply(data, {}, None, random_params=rp)
    >>> result["imputed_counts"].shape
    (100, 2000)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `TransformerDenoiserConfig` | Denoiser configuration. | required |
| rngs | `Rngs \| None` | Random number generators for parameter initialisation. | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply transformer denoising to single-cell count data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: | required |
| state | `PyTree` | Element state (passed through unchanged). | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| random_params | `Any` | JAX random key for mask generation. | None |
| stats | `dict[str, Any] \| None` | Not used. | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
TransformerDenoiserConfig
diffbio.operators.singlecell.imputation.TransformerDenoiserConfig
dataclass
TransformerDenoiserConfig(
n_genes: int = 2000,
hidden_dim: int = 128,
num_layers: int = 2,
num_heads: int = 4,
mask_ratio: float = 0.15,
dropout_rate: float = 0.1,
)
Bases: MaskedGeneTransformerConfigBase
Configuration for transformer-based gene denoising.
The denoiser treats genes as tokens: each gene has an expression value and a gene ID. A random fraction of genes is masked (expression zeroed) and the transformer predicts the original expression from the unmasked context.
__post_init__
Default masked-gene operators to sampled stochastic execution.
DifferentiableDoubletScorer
diffbio.operators.singlecell.doublet_detection.DifferentiableDoubletScorer
DifferentiableDoubletScorer(
config: DoubletScorerConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Differentiable Scrublet-style doublet detection operator.
Detects doublets by generating synthetic doublet profiles from random cell pairs, embedding real and synthetic cells into PCA space, and scoring each real cell via the Bayesian k-NN likelihood ratio from Scrublet (Wolock et al., 2019).
Algorithm
- Generate n_cells * sim_doublet_ratio synthetic doublets
- Concatenate real and synthetic cells
- PCA-embed via truncated SVD
- Compute pairwise distances in PCA space
- Adjust k upward: k_adj = round(k * (1 + n_syn / n_cells))
- Count soft synthetic neighbors in each real cell's k-NN
- Compute Laplace-smoothed fraction q of synthetic neighbors
- Bayesian likelihood ratio: Ld = q * rho / r / denom
- Apply sigmoid threshold for predicted doublet calls
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `DoubletScorerConfig` | DoubletScorerConfig with operator parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |
Example
    >>> config = DoubletScorerConfig(n_neighbors=30, n_pca_components=30,
    ...                              n_genes=2000)
    >>> scorer = DifferentiableDoubletScorer(config, rngs=nnx.Rngs(0))
    >>> rng = jax.random.key(0)
    >>> rp = scorer.generate_random_params(rng, {"counts": (500, 2000)})
    >>> result, state, meta = scorer.apply({"counts": counts}, {}, None,
    ...                                    random_params=rp)
    >>> result["doublet_scores"].shape
    (500,)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `DoubletScorerConfig` | Doublet scorer configuration. | required |
| rngs | `Rngs \| None` | Random number generators for stochastic operations. | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply doublet detection to single-cell count data.
Implements Scrublet's Bayesian k-NN likelihood-ratio scoring:
- Generate n_cells * sim_doublet_ratio synthetic doublets
- Adjust k upward: k_adj = round(k * (1 + n_syn / n_cells))
- For each real cell, count synthetic neighbors in its k-NN
- Compute Laplace-smoothed fraction q = (syn_count + 1) / (k_adj + 2)
- Bayesian likelihood ratio: Ld = q * rho / r / (1 - rho - q*(1 - rho - rho/r))
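The scoring formulas above translate directly into a few lines of JAX. In this sketch `r` is assumed to be the ratio of synthetic to real cells (`n_syn / n_cells`) and `rho` the expected doublet rate. The function name `doublet_likelihood` is hypothetical, and the neighbor counts are made-up inputs rather than the output of a real k-NN search.

```python
import jax.numpy as jnp


def doublet_likelihood(syn_count, k_adj, rho, r):
    """Scrublet-style likelihood ratio from synthetic-neighbor counts."""
    # Laplace-smoothed fraction of synthetic neighbors.
    q = (syn_count + 1.0) / (k_adj + 2.0)
    denom = 1.0 - rho - q * (1.0 - rho - rho / r)
    return q * rho / r / denom


n_cells, n_syn, k = 500, 1000, 30
r = n_syn / n_cells                            # assumed meaning of r
k_adj = round(k * (1 + n_syn / n_cells))       # adjusted neighborhood size

syn_counts = jnp.array([0.0, 45.0, 90.0])      # illustrative neighbor counts
scores = doublet_likelihood(syn_counts, k_adj, rho=0.06, r=r)
```

With `k_adj = 90`, a cell with 45 synthetic neighbors has `q = 46 / 92 = 0.5`, giving `Ld = 0.015 / 0.485`. The score grows monotonically with the synthetic-neighbor count.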
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: | required |
| state | `PyTree` | Element state (passed through unchanged). | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| random_params | `Any` | JAX random key for synthetic doublet generation. | None |
| stats | `dict[str, Any] \| None` | Not used. | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
DoubletScorerConfig
diffbio.operators.singlecell.doublet_detection.DoubletScorerConfig
dataclass
DoubletScorerConfig(
n_neighbors: int = 30,
expected_doublet_rate: float = 0.06,
sim_doublet_ratio: float = 2.0,
n_pca_components: int = 30,
n_genes: int = 2000,
threshold_temperature: float = 10.0,
)
Bases: OperatorConfig
Configuration for Scrublet-style doublet detection.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_neighbors | `int` | Base number of nearest neighbors for scoring (adjusted upward to account for synthetic pool size). |
| expected_doublet_rate | `float` | Prior expected fraction of doublets (rho in the Bayesian likelihood ratio). |
| sim_doublet_ratio | `float` | Ratio of synthetic doublets to real cells. Scrublet default is 2.0, meaning 2x as many synthetics as real cells. |
| n_pca_components | `int` | Number of PCA components for embedding. |
| n_genes | `int` | Number of genes in expression profiles. |
| threshold_temperature | `float` | Temperature for sigmoid doublet thresholding. |
DifferentiableSoloDetector
diffbio.operators.singlecell.doublet_detection.DifferentiableSoloDetector
DifferentiableSoloDetector(
config: SoloDetectorConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: CountVAEBackboneMixin, EncoderDecoderOperator
Solo-style VAE doublet detector.
Detects doublets by encoding cells through a VAE, generating synthetic doublets in count space, then classifying real vs synthetic cells in the VAE latent space.
Algorithm
- Generate synthetic doublets by summing random cell pairs
- Concatenate real and synthetic counts
- Encode all cells through the VAE encoder to obtain (mean, logvar)
- Sample latent z via the reparameterization trick
- Run a binary classifier on real-cell latents
- Return doublet probabilities, labels, and latent representations
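The first two steps, generating synthetic doublets and concatenating them with the real cells, can be sketched as follows. Sampling pairs with replacement via `jax.random.randint` is an assumption, and the function name `make_synthetic_doublets` is hypothetical.

```python
import jax
import jax.numpy as jnp


def make_synthetic_doublets(key, counts, sim_doublet_ratio=2.0):
    """Sum the counts of randomly paired cells to simulate doublets."""
    n_cells = counts.shape[0]
    n_syn = int(n_cells * sim_doublet_ratio)
    key_a, key_b = jax.random.split(key)
    # Pairs are drawn with replacement (assumption); a cell may pair with itself.
    idx_a = jax.random.randint(key_a, (n_syn,), 0, n_cells)
    idx_b = jax.random.randint(key_b, (n_syn,), 0, n_cells)
    return counts[idx_a] + counts[idx_b]


counts = jnp.ones((5, 3))
doublets = make_synthetic_doublets(jax.random.key(0), counts)

# Real cells first, synthetic doublets after, ready for the encoder.
combined = jnp.concatenate([counts, doublets], axis=0)
```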
Architecture
- Encoder: counts -> log1p -> hidden layers (ReLU) -> (mean, logvar)
- Decoder: z -> hidden layers (ReLU) -> log_rate
- Classifier: z -> Linear -> ReLU -> Linear -> sigmoid
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `SoloDetectorConfig` | SoloDetectorConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |
Example
    >>> config = SoloDetectorConfig(n_genes=2000, latent_dim=10)
    >>> detector = DifferentiableSoloDetector(config, rngs=nnx.Rngs(42))
    >>> rng = jax.random.key(0)
    >>> rp = detector.generate_random_params(rng, {"counts": (500, 2000)})
    >>> result, _, _ = detector.apply({"counts": counts}, {}, None, random_params=rp)
    >>> result["doublet_probabilities"].shape
    (500,)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `SoloDetectorConfig` | Solo detector configuration. | required |
| rngs | `Rngs \| None` | Random number generators for initialization and sampling. | None |
| name | `str \| None` | Optional operator name. | None |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply Solo-style VAE doublet detection.
Steps
- Generate synthetic doublets from random cell pairs
- Concatenate real and synthetic counts
- Encode all cells to latent space (mean, logvar)
- Sample z via reparameterization trick
- Run classifier on real-cell latents
- Return probabilities, labels, and latent for real cells only
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing: | required |
| state | `PyTree` | Element state (passed through unchanged). | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| random_params | `Any` | JAX random key for synthetic doublet generation. | None |
| stats | `dict[str, Any] \| None` | Not used. | None |

Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
SoloDetectorConfig
diffbio.operators.singlecell.doublet_detection.SoloDetectorConfig
dataclass
SoloDetectorConfig(
n_genes: int = 2000,
latent_dim: int = 10,
hidden_dims: list[int] = (lambda: [128, 64])(),
classifier_hidden_dim: int = 64,
sim_doublet_ratio: float = 2.0,
)
Bases: OperatorConfig
Configuration for Solo-style VAE doublet detection.
Solo (Bernstein et al., Cell Systems 2020) trains a VAE on real cells, generates synthetic doublets, encodes both into latent space, and trains a binary classifier to distinguish singlets from doublets.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_genes | `int` | Number of genes in expression profiles. |
| latent_dim | `int` | Dimension of the VAE latent space. |
| hidden_dims | `list[int]` | Hidden layer dimensions for encoder/decoder. |
| classifier_hidden_dim | `int` | Hidden dimension for the latent-space classifier. |
| sim_doublet_ratio | `float` | Ratio of synthetic doublets to real cells. |
DifferentiableCellCommunication
diffbio.operators.singlecell.communication.DifferentiableCellCommunication
DifferentiableCellCommunication(
config: CellCommunicationConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: GraphOperator
GNN-based differentiable cell-cell communication analysis.
Analyses inter-cellular signaling by applying GATv2 graph attention on a spatial cell graph whose edges carry ligand-receptor expression features.
Algorithm
- Build per-edge L-R expression features from counts and lr_pairs: for each edge (i, j) the feature vector is the concatenation of [L_expr[source], R_expr[target]] across all L-R pairs, projected to edge_features_dim.
- Project per-node gene expression to initial node embeddings.
- Apply stacked GATv2 layers (SpatialAttentionGNN) for message passing on the spatial cell graph.
- Decode node embeddings into per-node pathway activity and per-node communication scores via SignalingDecoder.
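The first step, building raw per-edge ligand-receptor features, can be sketched with array indexing. The layout used here is an assumption: the graph is given as `senders` and `receivers` index arrays, `lr_pairs` is an integer array of shape (n_pairs, 2) holding (ligand, receptor) gene indices, and ligand features are placed before receptor features. The learned projection to `edge_features_dim` is omitted, and the function name `lr_edge_features` is hypothetical.

```python
import jax.numpy as jnp


def lr_edge_features(counts, senders, receivers, lr_pairs):
    """Concatenate ligand expression at the source with receptor expression at the target."""
    ligand_expr = counts[senders][:, lr_pairs[:, 0]]      # (n_edges, n_pairs)
    receptor_expr = counts[receivers][:, lr_pairs[:, 1]]  # (n_edges, n_pairs)
    return jnp.concatenate([ligand_expr, receptor_expr], axis=-1)


counts = jnp.array([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])       # 2 cells, 3 genes
senders = jnp.array([0])                    # one edge: cell 0 -> cell 1
receivers = jnp.array([1])
lr_pairs = jnp.array([[0, 2], [1, 0]])      # (ligand gene, receptor gene)

features = lr_edge_features(counts, senders, receivers, lr_pairs)
```

For the single edge, the ligand part is the sender's expression of genes 0 and 1, and the receptor part is the receiver's expression of genes 2 and 0.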
Inherits from GraphOperator to get:
- scatter_aggregate() for message aggregation
- global_pool() for graph-level pooling
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `CellCommunicationConfig` | CellCommunicationConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | None |
| name | `str \| None` | Optional operator name. | None |
Example
    >>> config = CellCommunicationConfig(n_genes=50, n_lr_pairs=3, hidden_dim=32)
    >>> op = DifferentiableCellCommunication(config, rngs=nnx.Rngs(0))
    >>> data = {"counts": counts, "spatial_graph": graph, "lr_pairs": pairs}
    >>> result, state, meta = op.apply(data, {}, None)
    >>> result["communication_scores"].shape
    (n_cells, 3)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
CellCommunicationConfig
|
Cell communication configuration. |
required |
rngs
|
Rngs | None
|
Random number generators for parameter initialization. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply GNN-based cell-cell communication analysis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Element state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Element metadata (passed through unchanged). |
required |
random_params
|
Any
|
Not used (non-stochastic operator). |
None
|
stats
|
dict[str, Any] | None
|
Not used. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of (transformed_data, state, metadata): - transformed_data contains all original keys plus:
|
CellCommunicationConfig¤
diffbio.operators.singlecell.communication.CellCommunicationConfig
dataclass
¤
CellCommunicationConfig(
hidden_dim: int = 64,
num_heads: int = 4,
num_gnn_layers: int = 2,
n_pathways: int = 20,
dropout_rate: float = 0.1,
n_genes: int = 2000,
n_lr_pairs: int = 10,
edge_features_dim: int = 8,
)
Bases: _CellCommunicationInputConfig, _CellCommunicationModelConfig, OperatorConfig
Configuration for GNN-based cell-cell communication analysis.
DifferentiableLigandReceptor¤
diffbio.operators.singlecell.communication.DifferentiableLigandReceptor
¤
DifferentiableLigandReceptor(
config: LRScoringConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable ligand-receptor co-expression scoring operator.
Scores cell-cell communication by computing adjacency-weighted co-expression of ligand-receptor gene pairs. For each pair, the score at each receiver cell is the sum of sender ligand expression times receiver receptor expression, weighted by a fuzzy k-NN adjacency graph.
Algorithm
- Build a symmetric fuzzy k-NN adjacency from the count matrix using compute_pairwise_distances, compute_fuzzy_membership, and symmetrize_graph.
- For each L-R pair (ligand_idx, receptor_idx): score_i = sum_j(adjacency[i,j] * L[j] * R[i]), where L[j] is the sender's ligand expression and R[i] is the receiver's receptor expression.
- Compute analytical soft p-values via z-score comparison against an expected null distribution.
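The scoring formula score_i = sum_j(adjacency[i,j] * L[j] * R[i]) reduces to one matrix product per ligand followed by an element-wise multiply. The sketch below assumes a precomputed adjacency matrix and an (n_pairs, 2) lr_pairs array; the function name is hypothetical, and the p-value step is left out because the null distribution is not specified here.

```python
import jax.numpy as jnp


def lr_coexpression_scores(adjacency, counts, lr_pairs):
    """Adjacency-weighted ligand-receptor co-expression, one column per pair."""
    ligand = counts[:, lr_pairs[:, 0]]     # L[j], shape (n_cells, n_pairs)
    receptor = counts[:, lr_pairs[:, 1]]   # R[i], shape (n_cells, n_pairs)
    neighbor_ligand = adjacency @ ligand   # sum_j adjacency[i, j] * L[j]
    return neighbor_ligand * receptor      # multiply by the receiver's R[i]


adjacency = jnp.array([[0.0, 1.0, 0.0],
                       [1.0, 0.0, 0.5],
                       [0.0, 0.5, 0.0]])
counts = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
lr_pairs = jnp.array([[0, 1]])             # gene 0 is the ligand, gene 1 the receptor
lr_scores = lr_coexpression_scores(adjacency, counts, lr_pairs)
```

For this toy input the scores are 6, 14, and 9 for the three receiver cells, and the output shape (n_cells, n_pairs) matches the shape shown in the class example.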
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
LRScoringConfig
|
LRScoringConfig with operator parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Example
config = LRScoringConfig(n_neighbors=15)
op = DifferentiableLigandReceptor(config, rngs=nnx.Rngs(0))
data = {"counts": counts, "lr_pairs": jnp.array([[0, 1]])}
result, state, meta = op.apply(data, {}, None)
result["lr_scores"].shape  # (n_cells, 1)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
LRScoringConfig
|
L-R scoring configuration. |
required |
rngs
|
Rngs | None
|
Random number generators for parameter initialization. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply ligand-receptor co-expression scoring.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Element state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Element metadata (passed through unchanged). |
required |
random_params
|
Any
|
Not used (non-stochastic operator). |
None
|
stats
|
dict[str, Any] | None
|
Not used. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of (transformed_data, state, metadata): - transformed_data contains:
|
LRScoringConfig¤
diffbio.operators.singlecell.communication.LRScoringConfig
dataclass
¤
LRScoringConfig(
n_neighbors: int = 15,
temperature: float = 1.0,
learnable_temperature: bool = False,
metric: str = "euclidean",
kh: float = 0.5,
hill_n: float = 1.0,
)
Bases: OperatorConfig
Configuration for ligand-receptor co-expression scoring.
Attributes:
| Name | Type | Description |
|---|---|---|
n_neighbors |
int
|
Number of nearest neighbors for k-NN graph. |
temperature |
float
|
Temperature for soft p-value sigmoid. |
learnable_temperature |
bool
|
Whether the temperature is a learnable parameter. |
metric |
str
|
Distance metric for k-NN, either |
kh |
float
|
Hill function half-maximal constant (CellChat default 0.5). |
hill_n |
float
|
Hill function cooperativity coefficient (CellChat default 1.0). |
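For readers unfamiliar with the two Hill parameters, the standard Hill function is x^n / (kh^n + x^n). The sketch below only illustrates that formula with the defaults listed above; where exactly the operator applies it is not stated in this reference, and the helper name is hypothetical.

```python
import jax.numpy as jnp


def hill(x, kh=0.5, hill_n=1.0):
    """Standard Hill saturation curve: 0 at x = 0, 0.5 at x = kh, approaching 1."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return x**hill_n / (kh**hill_n + x**hill_n)


hill_response = hill(jnp.array([0.0, 0.5, 2.0]))
```

With kh = 0.5 and hill_n = 1.0 the inputs 0, 0.5, and 2 map to 0, 0.5, and 0.8, so kh is the input at which the response is half-maximal.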
DifferentiableGRN¤
diffbio.operators.singlecell.grn_inference.DifferentiableGRN
¤
DifferentiableGRN(
config: GRNInferenceConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Differentiable gene regulatory network inference operator.
Uses GATv2 graph attention on a TF-gene bipartite graph to infer regulatory strengths. Each TF is connected to every gene; the attention weight on each edge represents how strongly the TF regulates that gene.
This is a novel differentiable alternative to GENIE3's random forest feature importance scoring. The key insight is that in GENIE3, each gene's expression is predicted from TF expression, and feature importance measures regulatory strength. Here, GATv2 attention performs an analogous role: TF nodes attend to gene nodes, and the learned attention weights capture regulatory relationships.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
GRNInferenceConfig
|
GRNInferenceConfig with model parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Example
config = GRNInferenceConfig(n_tfs=5, n_genes=20, hidden_dim=16)
op = DifferentiableGRN(config, rngs=nnx.Rngs(0))
data = {"counts": counts, "tf_indices": jnp.arange(5)}
result, state, meta = op.apply(data, {}, None)
result["grn_matrix"].shape  # (5, 20)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
GRNInferenceConfig
|
GRN inference configuration. |
required |
rngs
|
Rngs | None
|
Random number generators for parameter initialization. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply differentiable GRN inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Element state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Element metadata (passed through unchanged). |
required |
random_params
|
Any
|
Not used (non-stochastic operator). |
None
|
stats
|
dict[str, Any] | None
|
Not used. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of (transformed_data, state, metadata): - transformed_data contains all original keys plus:
|
GRNInferenceConfig¤
diffbio.operators.singlecell.grn_inference.GRNInferenceConfig
dataclass
¤
GRNInferenceConfig(
n_tfs: int = 50,
n_genes: int = 2000,
hidden_dim: int = 64,
num_heads: int = 4,
sparsity_temperature: float = 0.1,
sparsity_lambda: float = 0.01,
)
Bases: OperatorConfig
Configuration for differentiable GRN inference.
Attributes:
| Name | Type | Description |
|---|---|---|
n_tfs |
int
|
Number of transcription factors. |
n_genes |
int
|
Number of genes in the expression matrix. |
hidden_dim |
int
|
Hidden dimension for GATv2 attention (must be divisible by num_heads). |
num_heads |
int
|
Number of attention heads in the GATv2 layer. |
sparsity_temperature |
float
|
Temperature for soft L1 sparsity gating. Lower values produce sharper thresholding toward zero. |
sparsity_lambda |
float
|
L1 regularization weight (used by downstream loss functions, not directly by the operator). |
DifferentiableMMDBatchCorrection¤
diffbio.operators.singlecell.enhanced_batch_correction.DifferentiableMMDBatchCorrection
¤
DifferentiableMMDBatchCorrection(
config: MMDBatchCorrectionConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: LossBalancingMixin, OperatorModule
Autoencoder batch correction with MMD regularisation.
Architecture
Encoder MLP maps gene expression to a latent representation, and a decoder MLP reconstructs the expression from that latent. The MMD loss penalises distributional mismatch between batches in latent space so the learned representation becomes batch-invariant.
Loss
reconstruction_mse + mmd(latent_batch_0, latent_batch_1, ...)
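The MMD term can be illustrated with a two-batch sketch. It assumes an RBF kernel of the form exp(-d² / (2 · bandwidth²)) and the biased estimator of squared MMD; the operator's exact kernel parameterisation and its handling of more than two batches are not documented here, so treat the names and formula below as assumptions.

```python
import jax.numpy as jnp


def rbf_kernel(x, y, bandwidth=1.0):
    """Pairwise RBF kernel between rows of x and rows of y."""
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dist / (2.0 * bandwidth**2))


def mmd_squared(x, y, bandwidth=1.0):
    """Biased estimate of squared MMD between two sets of latent codes."""
    k_xx = rbf_kernel(x, x, bandwidth).mean()
    k_yy = rbf_kernel(y, y, bandwidth).mean()
    k_xy = rbf_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy


latent_a = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
latent_b = latent_a + 3.0                      # same shape, shifted distribution
mmd_same = mmd_squared(latent_a, latent_a)     # identical batches
mmd_shifted = mmd_squared(latent_a, latent_b)  # clearly separated batches
```

The estimate is zero for identical batches and grows as the batches separate, which is what makes it usable as a batch-invariance penalty on the latent space.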
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
MMDBatchCorrectionConfig
|
MMDBatchCorrectionConfig with model hyper-parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Example
config = MMDBatchCorrectionConfig(n_genes=2000)
op = DifferentiableMMDBatchCorrection(config, rngs=nnx.Rngs(0))
result, _, _ = op.apply(data, {}, None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
MMDBatchCorrectionConfig
|
Operator configuration. |
required |
rngs
|
Rngs | None
|
Random number generators for weight initialisation. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Encode, decode, and compute MMD + reconstruction losses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Pipeline state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Pipeline metadata (passed through unchanged). |
required |
random_params
|
Any
|
Unused. |
None
|
stats
|
dict[str, Any] | None
|
Unused. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of |
MMDBatchCorrectionConfig¤
diffbio.operators.singlecell.enhanced_batch_correction.MMDBatchCorrectionConfig
dataclass
¤
MMDBatchCorrectionConfig(
n_genes: int = 2000,
hidden_dim: int = 128,
latent_dim: int = 64,
kernel_bandwidth: float = 1.0,
use_gradnorm: bool = False,
)
Bases: OperatorConfig
Configuration for MMD-based batch correction.
Attributes:
| Name | Type | Description |
|---|---|---|
n_genes |
int
|
Number of input genes (features). |
hidden_dim |
int
|
Width of hidden layers in the autoencoder. |
latent_dim |
int
|
Dimensionality of the latent space. |
kernel_bandwidth |
float
|
Bandwidth for the RBF kernel in the MMD loss. |
use_gradnorm |
bool
|
Whether to use GradNormBalancer for multi-task loss balancing between reconstruction and MMD losses. |
DifferentiableWGANBatchCorrection¤
diffbio.operators.singlecell.enhanced_batch_correction.DifferentiableWGANBatchCorrection
¤
DifferentiableWGANBatchCorrection(
config: WGANBatchCorrectionConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: LossBalancingMixin, OperatorModule
Adversarial autoencoder batch correction with Wasserstein GAN loss.
Architecture
An encoder (generator) maps gene expression to a batch-invariant latent space and a decoder reconstructs the expression. A separate discriminator tries to predict the batch label from the latent representation. Gradient reversal (a la scGPT) ensures the encoder learns to fool the discriminator, yielding batch-invariant latents.
Losses
- generator_loss: Wasserstein generator loss (encoder wants to fool the discriminator) plus reconstruction MSE.
- discriminator_loss: Wasserstein discriminator/critic loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
WGANBatchCorrectionConfig
|
WGANBatchCorrectionConfig with model hyper-parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Example
config = WGANBatchCorrectionConfig(n_genes=2000)
op = DifferentiableWGANBatchCorrection(config, rngs=nnx.Rngs(0))
result, _, _ = op.apply(data, {}, None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
WGANBatchCorrectionConfig
|
Operator configuration. |
required |
rngs
|
Rngs | None
|
Random number generators for weight initialisation. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Encode, decode, and compute adversarial + reconstruction losses.
The discriminator receives latent codes through a gradient reversal layer so that encoder gradients push toward batch invariance while discriminator gradients push toward better batch classification.
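A gradient reversal layer is the identity in the forward pass and negates the gradient in the backward pass. One common way to write it in JAX uses jax.custom_vjp, as in the sketch below; this is a generic illustration and not necessarily how the operator implements it.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def gradient_reversal(x):
    """Identity in the forward pass."""
    return x


def _reversal_fwd(x):
    return x, None          # no residuals are needed for the backward pass


def _reversal_bwd(_, cotangent):
    return (-cotangent,)    # flip the sign of the incoming gradient


gradient_reversal.defvjp(_reversal_fwd, _reversal_bwd)

z = jnp.array([1.0, 2.0])
grad_plain = jax.grad(lambda v: jnp.sum(v**2))(z)
grad_reversed = jax.grad(lambda v: jnp.sum(gradient_reversal(v) ** 2))(z)
```

Here grad_reversed is the negation of grad_plain, so any module upstream of the layer is pushed to increase the loss that the downstream module is minimising.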
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Pipeline state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Pipeline metadata (passed through unchanged). |
required |
random_params
|
Any
|
Unused. |
None
|
stats
|
dict[str, Any] | None
|
Unused. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of |
WGANBatchCorrectionConfig¤
diffbio.operators.singlecell.enhanced_batch_correction.WGANBatchCorrectionConfig
dataclass
¤
WGANBatchCorrectionConfig(
n_genes: int = 2000,
hidden_dim: int = 128,
latent_dim: int = 64,
discriminator_hidden_dim: int = 64,
use_gradnorm: bool = False,
)
Bases: OperatorConfig
Configuration for WGAN-based batch correction.
Attributes:
| Name | Type | Description |
|---|---|---|
n_genes |
int
|
Number of input genes (features). |
hidden_dim |
int
|
Width of hidden layers in the generator autoencoder. |
latent_dim |
int
|
Dimensionality of the latent space. |
discriminator_hidden_dim |
int
|
Width of hidden layers in the discriminator. |
use_gradnorm |
bool
|
Whether to use GradNormBalancer for multi-task loss balancing between generator and discriminator losses. |
DifferentiableSwitchDE¤
diffbio.operators.singlecell.switch_de.DifferentiableSwitchDE
¤
DifferentiableSwitchDE(
config: SwitchDEConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable sigmoidal switch model for differential expression.
Models gene expression as a sigmoidal function of pseudotime:
g(t) = amplitude * sigmoid((t - t_switch) / temperature) + baseline
Each gene has learnable parameters for switch time, amplitude, and baseline expression level. The switch score quantifies how strongly a gene switches, computed as the maximum sigmoid derivative scaled by amplitude.
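The model and its switch score can be sketched as follows. The maximum slope of sigmoid((t - t_switch) / temperature) is 1 / (4 · temperature), reached at t = t_switch, so this sketch takes the score as amplitude / (4 · temperature). Whether the operator uses exactly this expression (for example, with an absolute value) is an assumption, as are the function names.

```python
import jax
import jax.numpy as jnp


def switch_expression(t, t_switch, amplitude, baseline, temperature=1.0):
    """g(t) = amplitude * sigmoid((t - t_switch) / temperature) + baseline, per gene."""
    z = (t[:, None] - t_switch[None, :]) / temperature
    return amplitude[None, :] * jax.nn.sigmoid(z) + baseline[None, :]


def switch_score(amplitude, temperature=1.0):
    """Amplitude times the maximum sigmoid slope, 1 / (4 * temperature)."""
    return amplitude / (4.0 * temperature)


pseudotime = jnp.array([0.0, 0.5, 1.0])
t_switch = jnp.array([0.5, 0.2])
amplitude = jnp.array([2.0, -1.0])
baseline = jnp.array([1.0, 3.0])

predicted = switch_expression(pseudotime, t_switch, amplitude, baseline)
scores = switch_score(amplitude)

# Cross-check for gene 0: the slope of g(t) at t = t_switch equals its score.
slope_at_switch = jax.grad(
    lambda t: switch_expression(jnp.array([t]), t_switch, amplitude, baseline)[0, 0]
)(0.5)
```

A negative amplitude models a gene that switches off along pseudotime, which gives a negative score under this convention.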
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
SwitchDEConfig
|
SwitchDEConfig with model parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
SwitchDEConfig
|
Switch DE configuration. |
required |
rngs
|
Rngs | None
|
Random number generators for initialization. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply sigmoidal switch DE model to single-cell data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing: - "counts": Gene expression counts (n_cells, n_genes) - "pseudotime": Pseudotime values per cell (n_cells,) |
required |
state
|
PyTree
|
Element state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Element metadata (passed through unchanged). |
required |
random_params
|
Any
|
Not used. |
None
|
stats
|
dict[str, Any] | None
|
Not used. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of (transformed_data, state, metadata): - transformed_data contains:
|
SwitchDEConfig¤
diffbio.operators.singlecell.switch_de.SwitchDEConfig
dataclass
¤
SwitchDEConfig(
n_genes: int = 2000,
temperature: float = 1.0,
learnable_temperature: bool = False,
)
Bases: OperatorConfig
Configuration for sigmoidal switch differential expression.
Attributes:
| Name | Type | Description |
|---|---|---|
n_genes |
int
|
Number of genes to model. |
temperature |
float
|
Temperature controlling sigmoid smoothness. Lower values produce sharper switch transitions. |
learnable_temperature |
bool
|
Whether temperature is a learnable parameter. |
DifferentiablePseudotime¤
diffbio.operators.singlecell.trajectory.DifferentiablePseudotime
¤
DifferentiablePseudotime(
config: PseudotimeConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Differentiable pseudotime computation via diffusion maps.
Constructs a k-NN affinity graph, builds a Markov transition matrix, and computes diffusion components through subspace iteration with QR orthogonalization. Pseudotime is defined as the Euclidean distance in diffusion-component space from the designated root cell.
Algorithm
- Compute pairwise distances between cells.
- Compute fuzzy membership with local bandwidth (k-th neighbor).
- Symmetrize the graph via fuzzy set union.
- Row-normalize to obtain a Markov transition matrix.
- Extract the top n_diffusion_components eigenvectors of the symmetrized transition matrix via subspace iteration (repeated matmul + QR), excluding the trivial eigenvalue 1.
- Weight eigenvectors by their Rayleigh-quotient eigenvalues to form diffusion components.
- Pseudotime = L2 distance from root cell in diffusion-component space.
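The last four steps can be sketched on a precomputed symmetric affinity matrix. This is a simplified illustration under several assumptions: the symmetric normalisation D^(-1/2) A D^(-1/2) stands in for the symmetrized transition matrix, the trivial eigenvector is removed by deflation, and a fixed deterministic starting basis is used. The function name and these choices are not taken from the operator's source.

```python
import jax.numpy as jnp


def diffusion_pseudotime(affinity, n_components=2, root=0, n_iters=50):
    """Pseudotime as distance from the root cell in diffusion-component space."""
    degree = affinity.sum(axis=1)
    d_inv_sqrt = 1.0 / jnp.sqrt(degree)
    # Symmetric matrix with the same eigenvalues as the row-normalised Markov matrix.
    sym = affinity * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    # The trivial eigenvector (eigenvalue 1) is proportional to sqrt(degree); deflate it.
    trivial = jnp.sqrt(degree) / jnp.linalg.norm(jnp.sqrt(degree))
    sym = sym - jnp.outer(trivial, trivial)
    # Subspace iteration: repeated matmul followed by QR orthogonalisation.
    basis = jnp.eye(affinity.shape[0])[:, :n_components]
    for _ in range(n_iters):
        basis, _upper = jnp.linalg.qr(sym @ basis)
    # Rayleigh quotients of the orthonormal columns approximate the eigenvalues.
    eigenvalues = jnp.sum(basis * (sym @ basis), axis=0)
    components = basis * eigenvalues[None, :]
    return jnp.linalg.norm(components - components[root], axis=1)


positions = jnp.arange(5.0)  # five cells on a line
affinity = jnp.exp(-((positions[:, None] - positions[None, :]) ** 2))
pseudotime = diffusion_pseudotime(affinity, n_components=2, root=0)
```

Because every step is a matmul, a QR factorisation, or a norm, gradients can flow from the pseudotime values back to the affinity matrix.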
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
PseudotimeConfig
|
PseudotimeConfig with operator parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators (unused, kept for API). |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Example
config = PseudotimeConfig(n_neighbors=15, n_diffusion_components=10)
op = DifferentiablePseudotime(config)
data = {"embeddings": jnp.ones((50, 20))}
result, state, meta = op.apply(data, {}, None)
result["pseudotime"].shape  # (50,)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
PseudotimeConfig
|
Pseudotime configuration. |
required |
rngs
|
Rngs | None
|
Random number generators (unused, present for API consistency). |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply pseudotime computation to cell embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Element state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Element metadata (passed through unchanged). |
required |
random_params
|
Any
|
Not used (deterministic operator). |
None
|
stats
|
dict[str, Any] | None
|
Not used. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of (transformed_data, state, metadata): - transformed_data contains:
|
PseudotimeConfig¤
diffbio.operators.singlecell.trajectory.PseudotimeConfig
dataclass
¤
PseudotimeConfig(
n_neighbors: int = 15,
n_diffusion_components: int = 10,
root_cell_index: int = 0,
metric: str = "euclidean",
)
Bases: OperatorConfig
Configuration for pseudotime computation.
Attributes:
| Name | Type | Description |
|---|---|---|
n_neighbors |
int
|
Number of neighbors for k-NN graph construction. |
n_diffusion_components |
int
|
Number of diffusion map components to retain. |
root_cell_index |
int
|
Index of the root cell (pseudotime origin). |
metric |
str
|
Distance metric, |
DifferentiableFateProbability¤
diffbio.operators.singlecell.trajectory.DifferentiableFateProbability
¤
DifferentiableFateProbability(
config: FateProbabilityConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Differentiable fate probability estimation via absorption probabilities.
Given a Markov transition matrix and a set of terminal (absorbing) state indices, partitions cells into transient and absorbing sets and computes the probability that each transient cell will eventually reach each absorbing state.
Algorithm
- Partition states into transient (T) and absorbing (A).
- Extract sub-matrices Q = transition[T, T] and R = transition[T, A].
- Solve (I - Q) @ B = R for B (absorption probabilities).
- Assign probability 1 to each absorbing state for itself.
The linear solve jnp.linalg.solve is fully differentiable.
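The four steps map directly onto a few array operations. The sketch below uses a hypothetical helper name and concrete (non-traced) index computation, so it is meant for eager execution rather than as a drop-in for the operator.

```python
import jax.numpy as jnp


def absorption_probabilities(transition, terminal_states):
    """Probability that each state is eventually absorbed in each terminal state."""
    n_states = transition.shape[0]
    n_terminal = terminal_states.shape[0]
    is_terminal = jnp.zeros(n_states, dtype=bool).at[terminal_states].set(True)
    transient = jnp.where(~is_terminal)[0]                  # concrete indices (eager only)
    q = transition[jnp.ix_(transient, transient)]           # transient -> transient
    r = transition[jnp.ix_(transient, terminal_states)]     # transient -> absorbing
    b = jnp.linalg.solve(jnp.eye(q.shape[0]) - q, r)        # (I - Q) @ B = R
    fate = jnp.zeros((n_states, n_terminal)).at[transient].set(b)
    # Each absorbing state reaches itself with probability 1.
    return fate.at[terminal_states, jnp.arange(n_terminal)].set(1.0)


# Gambler's ruin on four states: states 0 and 3 absorb, 1 and 2 step left/right equally.
transition = jnp.array([[1.0, 0.0, 0.0, 0.0],
                        [0.5, 0.0, 0.5, 0.0],
                        [0.0, 0.5, 0.0, 0.5],
                        [0.0, 0.0, 0.0, 1.0]])
fate_probabilities = absorption_probabilities(transition, jnp.array([0, 3]))
```

For this chain, state 1 ends in state 0 with probability 2/3 and in state 3 with probability 1/3, and every row of the result sums to 1.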
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
FateProbabilityConfig
|
FateProbabilityConfig with operator parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators (unused, kept for API). |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Example
config = FateProbabilityConfig(n_macrostates=2)
op = DifferentiableFateProbability(config)
data = {"transition_matrix": T, "terminal_states": jnp.array([18, 19])}
result, state, meta = op.apply(data, {}, None)
result["fate_probabilities"].shape  # (20, 2)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
FateProbabilityConfig
|
Fate probability configuration. |
required |
rngs
|
Rngs | None
|
Random number generators (unused, present for API consistency). |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply fate probability estimation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Element state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Element metadata (passed through unchanged). |
required |
random_params
|
Any
|
Not used (deterministic operator). |
None
|
stats
|
dict[str, Any] | None
|
Not used. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of (transformed_data, state, metadata): - transformed_data contains:
|
FateProbabilityConfig¤
diffbio.operators.singlecell.trajectory.FateProbabilityConfig
dataclass
¤
Bases: OperatorConfig
Configuration for fate probability computation.
Attributes:
| Name | Type | Description |
|---|---|---|
n_macrostates |
int
|
Number of macrostates (terminal fates). |
DifferentiableSpatialDomain¤
diffbio.operators.singlecell.spatial_domains.DifferentiableSpatialDomain
¤
DifferentiableSpatialDomain(
config: SpatialDomainConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: GraphOperator
STAGATE-inspired differentiable spatial domain identification.
Identifies spatial domains by combining gene expression with spatial coordinates through a graph attention autoencoder. The encoder uses dual-graph GATv2 attention (full + pruned k-NN graphs), and soft domain assignments are computed via learned prototypes with softmax.
Algorithm
- Build spatial k-NN graph from coordinates (full + pruned/mutual).
- Apply GATv2 encoder: counts -> spatial embeddings (dual-graph attention weighted by alpha).
- Decoder: reconstruct gene expression from embeddings (autoencoder).
- Soft domain assignment via softmax on learned domain prototypes.
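Two pieces of this algorithm are easy to show in isolation: building the full and pruned (mutual) k-NN graphs, and turning embeddings into soft domain assignments. The sketch below is illustrative only. It assumes dense boolean adjacency matrices and a dot-product similarity to the prototypes, and it leaves out the GATv2 encoder and decoder; none of the helper names come from the library.

```python
import jax
import jax.numpy as jnp


def knn_graphs(coords, k):
    """Return (full, mutual) boolean k-NN adjacency matrices from spatial coordinates."""
    n_cells = coords.shape[0]
    sq_dist = jnp.sum((coords[:, None, :] - coords[None, :, :]) ** 2, axis=-1)
    sq_dist = sq_dist + jnp.diag(jnp.full(n_cells, jnp.inf))  # exclude self-neighbours
    neighbors = jnp.argsort(sq_dist, axis=1)[:, :k]
    rows = jnp.arange(n_cells)[:, None]
    full = jnp.zeros((n_cells, n_cells), dtype=bool).at[rows, neighbors].set(True)
    mutual = full & full.T  # keep an edge only if both cells pick each other
    return full, mutual


def soft_domain_assignment(embeddings, prototypes):
    """Softmax over similarity to each domain prototype (dot product assumed)."""
    return jax.nn.softmax(embeddings @ prototypes.T, axis=-1)


coords = jnp.array([[0.0], [1.0], [2.5], [10.0]])
full_graph, mutual_graph = knn_graphs(coords, k=1)

embeddings = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0]])
prototypes = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
assignments = soft_domain_assignment(embeddings, prototypes)
```

In this toy example the full graph has four directed edges and only the pair (0, 1) survives pruning, which shows how the mutual graph drops one-sided neighbour relations.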
Inherits from GraphOperator to get:
- scatter_aggregate() for message aggregation
- global_pool() for graph-level pooling
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
SpatialDomainConfig
|
SpatialDomainConfig with model parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
SpatialDomainConfig
|
Spatial domain configuration. |
required |
rngs
|
Rngs | None
|
Random number generators for parameter initialization. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply spatial domain identification to spatial transcriptomics data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Element state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Element metadata (passed through unchanged). |
required |
random_params
|
Any
|
Not used (non-stochastic operator). |
None
|
stats
|
dict[str, Any] | None
|
Not used. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of (transformed_data, state, metadata): - transformed_data contains all original keys plus:
|
SpatialDomainConfig¤
diffbio.operators.singlecell.spatial_domains.SpatialDomainConfig
dataclass
¤
SpatialDomainConfig(
n_genes: int = 2000,
hidden_dim: int = 64,
num_heads: int = 4,
n_domains: int = 7,
alpha: float = 0.8,
n_neighbors: int = 15,
)
Bases: OperatorConfig
Configuration for STAGATE-style spatial domain identification.
Attributes:
| Name | Type | Description |
|---|---|---|
n_genes |
int
|
Number of input genes. |
hidden_dim |
int
|
Latent embedding dimension. Must be divisible by num_heads. |
num_heads |
int
|
Number of GATv2 attention heads. |
n_domains |
int
|
Number of spatial domains to identify. |
alpha |
float
|
Weight for pruned graph in dual-graph attention (STAGATE default 0.8). At alpha=0, only the full k-NN graph is used. At alpha=1, only the pruned (mutual k-NN) graph is used. |
n_neighbors |
int
|
Number of nearest neighbors for spatial k-NN graph. |
DifferentiablePASTEAlignment¤
diffbio.operators.singlecell.spatial_domains.DifferentiablePASTEAlignment
¤
DifferentiablePASTEAlignment(
config: PASTEAlignmentConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: GraphOperator
PASTE-inspired differentiable spatial transcriptomics slice alignment.
Aligns two spatial transcriptomics slices by computing a fused cost that balances expression dissimilarity with spatial structure (Gromov-Wasserstein) and solving for the optimal transport plan via differentiable Sinkhorn.
Algorithm
- Compute expression dissimilarity between slices (Euclidean distance).
- Compute intra-slice spatial distance matrices.
- Compute Gromov-Wasserstein spatial cost that penalizes distortion of pairwise spatial relationships.
- Fuse costs: (1 - alpha) * expression_cost + alpha * spatial_GW_cost.
- Solve OT via SinkhornLayer for the differentiable transport plan.
- Align slice 2 coordinates using the transport plan.
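The optimal transport step can be illustrated with a log-domain Sinkhorn loop over a fixed cost matrix. This sketch assumes uniform marginals and a hypothetical function name; it is a generic entropic OT solver, not the library's SinkhornLayer, and it skips the Gromov-Wasserstein cost construction.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_plan(cost, epsilon=0.1, n_iters=100):
    """Entropy-regularised transport plan between uniform marginals."""
    n_rows, n_cols = cost.shape
    log_a = jnp.full(n_rows, -jnp.log(n_rows))  # log of uniform row marginal
    log_b = jnp.full(n_cols, -jnp.log(n_cols))  # log of uniform column marginal
    f = jnp.zeros(n_rows)
    g = jnp.zeros(n_cols)
    for _ in range(n_iters):
        # Alternate dual updates so row, then column, marginals are matched.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


cost = jnp.array([[0.0, 1.0],
                  [1.0, 0.0]])
plan = sinkhorn_plan(cost, epsilon=0.1, n_iters=100)
```

Smaller epsilon values give a sharper, closer-to-exact plan at the price of slower convergence, which is the trade-off controlled by sinkhorn_epsilon and sinkhorn_iters.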
Inherits from GraphOperator to get:
- scatter_aggregate() for message aggregation
- global_pool() for graph-level pooling
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
PASTEAlignmentConfig
|
PASTEAlignmentConfig with alignment parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
PASTEAlignmentConfig
|
PASTE alignment configuration. |
required |
rngs
|
Rngs | None
|
Random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply PASTE-style alignment between two spatial transcriptomics slices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
PyTree
|
Dictionary containing:
- |
required |
state
|
PyTree
|
Element state (passed through unchanged). |
required |
metadata
|
dict[str, Any] | None
|
Element metadata (passed through unchanged). |
required |
random_params
|
Any
|
Not used (non-stochastic operator). |
None
|
stats
|
dict[str, Any] | None
|
Not used. |
None
|
Returns:
| Type | Description |
|---|---|
tuple[PyTree, PyTree, dict[str, Any] | None]
|
Tuple of (transformed_data, state, metadata): - transformed_data contains all original keys plus:
|
PASTEAlignmentConfig¤
diffbio.operators.singlecell.spatial_domains.PASTEAlignmentConfig
dataclass
¤
PASTEAlignmentConfig(
alpha: float = 0.1,
sinkhorn_epsilon: float = 0.1,
sinkhorn_iters: int = 100,
)
Bases: OperatorConfig
Configuration for PASTE-style spatial transcriptomics slice alignment.
Attributes:
| Name | Type | Description |
|---|---|---|
alpha |
float
|
Balance between expression dissimilarity (linear term) and spatial Gromov-Wasserstein cost (quadratic term). 0 = pure expression matching, 1 = pure spatial structure matching. PASTE default: 0.1. |
sinkhorn_epsilon |
float
|
Entropy regularisation strength for the Sinkhorn optimal transport solver. |
sinkhorn_iters |
int
|
Number of Sinkhorn iterations. |
DifferentiableDifferentialDistribution¤
diffbio.operators.singlecell.differential_distribution.DifferentiableDifferentialDistribution
¤
DifferentiableDifferentialDistribution(
config: DifferentialDistributionConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable KS-test with learned pattern classification.
For each gene, this operator:
- Splits cells into two conditions based on binary condition labels.
- Computes a soft empirical CDF using sigmoid smoothing: soft_CDF(x, values) = mean(sigmoid((x - values) / temperature))
- Computes a soft KS statistic as the smooth maximum of |CDF_A(x) - CDF_B(x)| over evaluation points, using logsumexp.
- Extracts distributional features (mean shift, variance ratio, zero-proportion difference) and passes them through a learned linear head to classify each gene into one of the pattern categories (shift, scale, both, none).
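The soft CDF and soft KS statistic for a single gene can be sketched as below. The choice of evaluation points and the helper names are assumptions, and the pattern-classification head is omitted. Note that the logsumexp smooth maximum is an upper bound on the true maximum, so the value exceeds the hard KS statistic by at most temperature * log(n_eval_points).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_cdf(x, values, temperature):
    """soft_CDF(x, values) = mean(sigmoid((x - values) / temperature))."""
    return jnp.mean(jax.nn.sigmoid((x[:, None] - values[None, :]) / temperature), axis=1)


def soft_ks(values_a, values_b, eval_points, temperature=0.1):
    """Smooth maximum of |CDF_A(x) - CDF_B(x)| over the evaluation points."""
    gap = jnp.abs(
        soft_cdf(eval_points, values_a, temperature)
        - soft_cdf(eval_points, values_b, temperature)
    )
    return temperature * logsumexp(gap / temperature)


condition_a = jnp.linspace(0.0, 1.0, 20)
condition_b = condition_a + 2.0                 # same shape, shifted by 2
eval_points = jnp.linspace(-1.0, 4.0, 50)

ks_same = soft_ks(condition_a, condition_a, eval_points)
ks_shifted = soft_ks(condition_a, condition_b, eval_points)
```

The shifted pair yields a clearly larger statistic than the identical pair, and lowering the temperature tightens both the CDF and the maximum toward their hard counterparts.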
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
DifferentialDistributionConfig
|
DifferentialDistributionConfig with model parameters. |
required |
rngs
|
Rngs | None
|
Flax NNX random number generators. |
None
|
name
|
str | None
|
Optional operator name. |
None
|
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DifferentialDistributionConfig | Differential distribution configuration. | required |
| rngs | Rngs \| None | Random number generators for parameter initialisation. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply differentiable differential distribution testing.
For each gene, computes a soft KS statistic and classifies the distributional difference pattern using a learned linear head.
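The pattern-classification step can be pictured with a small NumPy sketch. Everything here is an assumption made for illustration: the exact feature definitions, the hard zero test (a differentiable version would need a smooth surrogate), the weight shapes, and the helper names `distribution_features` and `classify_pattern`.

```python
import numpy as np


def distribution_features(values_a, values_b, eps=1e-8):
    """Three per-gene features comparing condition B against condition A."""
    mean_shift = values_b.mean() - values_a.mean()
    var_ratio = (values_b.var() + eps) / (values_a.var() + eps)
    zero_diff = np.mean(values_b == 0) - np.mean(values_a == 0)
    return np.array([mean_shift, var_ratio, zero_diff])


def classify_pattern(features, weight, bias):
    """Linear head followed by a softmax over pattern classes."""
    logits = features @ weight + bias
    logits = logits - logits.max()  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()


rng = np.random.default_rng(0)
counts = rng.poisson(2.0, size=40).astype(float)  # one gene, 40 cells
condition_labels = np.arange(40) % 2              # binary labels

values_a = counts[condition_labels == 0]
values_b = counts[condition_labels == 1]

# Stand-ins for learned parameters: 3 features -> 4 pattern classes.
weight = rng.normal(scale=0.1, size=(3, 4))
bias = np.zeros(4)

probs = classify_pattern(distribution_features(values_a, values_b), weight, bias)
```

With untrained weights the probabilities carry no meaning. The sketch only shows the data flow from per-condition values to a distribution over the four pattern classes.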
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing: "counts": Gene expression matrix (n_cells, n_genes); "condition_labels": Binary condition labels (n_cells,). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata). |
DifferentialDistributionConfig¤
diffbio.operators.singlecell.differential_distribution.DifferentialDistributionConfig
dataclass
¤
DifferentialDistributionConfig(
n_genes: int = 2000,
temperature: float = 1.0,
learnable_temperature: bool = False,
n_pattern_classes: int = 4,
)
Bases: OperatorConfig
Configuration for differentiable differential distribution testing.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_genes | int | Number of genes to analyse. |
| temperature | float | Temperature controlling sigmoid smoothness in the soft CDF and logsumexp soft max. Lower values yield sharper approximations closer to the true KS statistic. |
| learnable_temperature | bool | Whether temperature is a learnable parameter. |
| n_pattern_classes | int | Number of distributional pattern categories. Default 4 corresponds to (shift, scale, both, none). |
DifferentiableSimulator¤
diffbio.operators.singlecell.simulation.DifferentiableSimulator
¤
DifferentiableSimulator(
config: SimulationConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Splatter-style differentiable single-cell count simulator.
Generates realistic scRNA-seq count matrices following the Splatter generative model (Zappia et al., 2017), with all steps implemented as differentiable JAX operations.
Algorithm
- Gene means: softplus-transformed learnable logits, scaled by Gamma(shape, rate) random perturbation.
- Cell library sizes: LogNormal(lib_loc, lib_scale) sampling.
- Group assignments: cells divided evenly across groups, with learnable group logits enabling soft assignment.
- DE fold-changes: per-group per-gene LogNormal fold-changes, masked by a Bernoulli(de_prob) DE indicator.
- Cell means: lib_sizes * gene_means * group_fold_change * batch_effect.
- Batch effects: exp(learnable batch_shift) multiplicative scaling.
- Dropout: sigmoid-based keep probability as function of log(cell_means).
- Counts: cell_means * keep_prob (continuous relaxation of Poisson).
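The generative steps above can be traced end to end in a short NumPy sketch. It is not the library's implementation: the function name `simulate_counts`, the zero-initialised logits and batch shift, the round-robin group assignment, and the logistic dropout form (borrowed from the Splatter-style midpoint/shape parameterisation) are assumptions for illustration. The default values mirror SimulationConfig.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def simulate_counts(n_cells=6, n_genes=4, n_groups=2, seed=0,
                    mean_shape=0.6, mean_rate=0.3,
                    lib_loc=11.0, lib_scale=0.2,
                    de_prob=0.1, de_fac_loc=0.1, de_fac_scale=0.4,
                    dropout_mid=-1.0, dropout_shape=-0.5):
    rng = np.random.default_rng(seed)

    # Gene means: softplus(logits) scaled by a Gamma(shape, rate) draw.
    gene_logits = np.zeros(n_genes)  # stand-in for learnable logits
    gene_means = np.log1p(np.exp(gene_logits)) * rng.gamma(
        mean_shape, 1.0 / mean_rate, size=n_genes)

    # Library sizes: LogNormal(lib_loc, lib_scale).
    lib_sizes = rng.lognormal(lib_loc, lib_scale, size=n_cells)

    # Groups: cells split evenly (round-robin is an assumption).
    groups = np.arange(n_cells) % n_groups

    # DE fold-changes: LogNormal factors masked by Bernoulli(de_prob).
    de_mask = rng.random((n_groups, n_genes)) < de_prob
    factors = rng.lognormal(de_fac_loc, de_fac_scale, size=(n_groups, n_genes))
    fold_change = np.where(de_mask, factors, 1.0)

    # Batch effect: exp(batch_shift); a zero shift gives a factor of 1.
    batch_effect = np.exp(np.zeros(n_genes))

    cell_means = (lib_sizes[:, None] * gene_means[None, :]
                  * fold_change[groups] * batch_effect[None, :])

    # Dropout: assumed logistic in log(cell_means); keep = 1 - dropout.
    log_means = np.log(cell_means + 1e-8)
    keep_prob = 1.0 - sigmoid(dropout_shape * (log_means - dropout_mid))

    # Continuous relaxation: expected counts instead of Poisson samples.
    counts = cell_means * keep_prob
    return counts, cell_means, groups, de_mask


counts, cell_means, groups, de_mask = simulate_counts()
```

Because the final step multiplies means by a keep probability instead of sampling, the output is continuous and non-negative, which is what keeps the pipeline differentiable with respect to the underlying parameters.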
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | SimulationConfig | SimulationConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
Example
config = SimulationConfig(n_cells=100, n_genes=50, n_groups=2)
sim = DifferentiableSimulator(config, rngs=nnx.Rngs(0, sample=1))
rng = jax.random.key(0)
rp = sim.generate_random_params(rng, {})
result, state, meta = sim.apply({}, {}, None, random_params=rp)
result["counts"].shape  # (100, 50)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | SimulationConfig | Simulation configuration. | required |
| rngs | Rngs \| None | Random number generators for parameter initialization. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Simulate a single-cell count matrix.
Follows the Splatter generative model with all steps differentiable: gene means, library sizes, group DE, batch effects, dropout, and Poisson count generation (continuous relaxation).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Input dictionary (may be empty; existing keys are preserved). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Dictionary of JAX random keys from generate_random_params. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (output_data, state, metadata) where output_data contains all original data keys plus: "counts": Simulated count matrix (n_cells, n_genes); "group_labels": Hard group assignments (n_cells,); "batch_labels": Batch assignments (n_cells,); "gene_means": Per-gene expression means (n_genes,); "de_mask": Binary DE indicator (n_groups, n_genes). |
SimulationConfig¤
diffbio.operators.singlecell.simulation.SimulationConfig
dataclass
¤
SimulationConfig(
dropout_mid: float = -1.0,
dropout_shape: float = -0.5,
mean_shape: float = 0.6,
mean_rate: float = 0.3,
lib_loc: float = 11.0,
lib_scale: float = 0.2,
de_prob: float = 0.1,
de_fac_loc: float = 0.1,
de_fac_scale: float = 0.4,
n_cells: int = 500,
n_genes: int = 200,
n_groups: int = 3,
n_batches: int = 1,
)
Bases: _SimulationSizeConfig, _SimulationDistributionConfig, _SimulationDropoutConfig, OperatorConfig
Configuration for DifferentiableSimulator.
DifferentiableArchetypalAnalysis¤
diffbio.operators.singlecell.archetypes.DifferentiableArchetypalAnalysis
¤
DifferentiableArchetypalAnalysis(
config: ArchetypalAnalysisConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable archetypal analysis with softmax simplex constraints.
Each cell is encoded into archetype weight space via an MLP, then temperature-controlled softmax produces simplex weights. The reconstruction is the convex combination of learnable archetype prototypes, enabling end-to-end gradient-based optimisation.
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
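The encode, softmax, and reconstruct flow described above can be illustrated with a NumPy forward pass. The encoder layout (one hidden ReLU layer), the weight names, and the random initialisation are assumptions for the sketch. The library's encoder and training loop may differ.

```python
import numpy as np


def softmax(logits, temperature=1.0):
    """Row-wise softmax with temperature, computed stably."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
n_cells, n_genes, hidden_dim, n_archetypes = 8, 20, 16, 3

x = rng.random((n_cells, n_genes))  # toy expression matrix

# Hypothetical encoder weights and archetype prototypes (learnable in practice).
w1 = rng.normal(scale=0.1, size=(n_genes, hidden_dim))
w2 = rng.normal(scale=0.1, size=(hidden_dim, n_archetypes))
archetypes = rng.random((n_archetypes, n_genes))

hidden = np.maximum(x @ w1, 0.0)            # encoder hidden layer (ReLU)
logits = hidden @ w2                        # archetype logits per cell
weights = softmax(logits, temperature=1.0)  # rows lie on the simplex
reconstruction = weights @ archetypes       # convex combination of archetypes
loss = np.mean((x - reconstruction) ** 2)   # reconstruction error to minimise
```

Since each row of `weights` is non-negative and sums to one, every reconstructed cell lies inside the convex hull of the archetypes. Lowering the temperature pushes the weights toward a single archetype per cell.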
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ArchetypalAnalysisConfig | ArchetypalAnalysisConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ArchetypalAnalysisConfig | Archetypal analysis configuration. | required |
| rngs | Rngs \| None | Random number generators for weight initialisation. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply archetypal analysis to a cell-by-gene count matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing the cell-by-gene count matrix. | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used. | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata). |
ArchetypalAnalysisConfig¤
diffbio.operators.singlecell.archetypes.ArchetypalAnalysisConfig
dataclass
¤
ArchetypalAnalysisConfig(
n_genes: int = 2000,
n_archetypes: int = 5,
hidden_dim: int = 64,
temperature: float = 1.0,
learnable_temperature: bool = False,
)
Bases: OperatorConfig
Configuration for DifferentiableArchetypalAnalysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_genes | int | Number of input genes (features per cell). |
| n_archetypes | int | Number of archetype prototypes to learn. |
| hidden_dim | int | Hidden dimension for the encoder MLP. |
| temperature | float | Softmax temperature (lower = sharper assignments). |
| learnable_temperature | bool | Whether temperature is a learnable parameter. |
DifferentiableOTTrajectory¤
diffbio.operators.singlecell.ot_trajectory.DifferentiableOTTrajectory
¤
DifferentiableOTTrajectory(
config: OTTrajectoryConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Waddington-OT-style differentiable trajectory inference.
Computes an optimal-transport plan between cell populations at two timepoints, estimates per-cell growth (proliferation) rates, and interpolates an intermediate cell distribution.
Algorithm
- Build the squared-Euclidean cost matrix C[i,j] = ||x_i - y_j||^2 between cells at t1 and t2.
- Compute the entropy-regularised transport plan via SinkhornLayer.
- Derive growth rates from the transport plan: cells that transport to more targets in t2 have higher proliferation. Normalise so that mean(growth_rates) == 1.
- Interpolate an intermediate distribution at time s: x_interp = (1-s) * x_t1 + s * (T @ x_t2) / T.sum(axis=1)
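The four steps above can be sketched with a log-domain Sinkhorn loop in NumPy. This stands in for SinkhornLayer, whose internals are not documented here. The uniform marginals, the helper names `logsumexp` and `sinkhorn_plan`, and the omission of growth_rate_regularization are assumptions of the sketch.

```python
import numpy as np


def logsumexp(z, axis):
    """Stable log(sum(exp(z))) along one axis."""
    m = z.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(z - m).sum(axis=axis))


def sinkhorn_plan(cost, epsilon=0.1, n_iters=100):
    """Entropy-regularised OT plan with uniform marginals (assumed)."""
    n, m = cost.shape
    log_a = np.full(n, -np.log(n))
    log_b = np.full(m, -np.log(m))
    f = np.zeros(n)  # dual potential for rows
    g = np.zeros(m)  # dual potential for columns
    for _ in range(n_iters):
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


rng = np.random.default_rng(0)
x_t1 = rng.random((5, 3))  # 5 cells at t1
x_t2 = rng.random((7, 3))  # 7 cells at t2

# Step 1: squared-Euclidean cost matrix, shape (5, 7).
cost = ((x_t1[:, None, :] - x_t2[None, :, :]) ** 2).sum(axis=-1)

# Step 2: transport plan.
plan = sinkhorn_plan(cost)

# Step 3: growth rates from row mass, normalised to mean 1.
row_mass = plan.sum(axis=1)
growth_rates = row_mass / row_mass.mean()

# Step 4: interpolate between each t1 cell and its barycentric image in t2.
s = 0.5
x_interp = (1 - s) * x_t1 + s * (plan @ x_t2) / row_mass[:, None]
```

Note that the division by the row sums needs a trailing axis (`row_mass[:, None]`) to broadcast against the (n_cells_t1, n_genes) product. With balanced uniform marginals the row sums converge to 1/n, so the growth rates in this sketch are all close to 1. A formulation that lets row mass vary would be needed to obtain informative growth rates.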
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | OTTrajectoryConfig | OTTrajectoryConfig with operator parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
Example
config = OTTrajectoryConfig(n_genes=100, sinkhorn_iters=50)
op = DifferentiableOTTrajectory(config)
data = {
    "counts_t1": jnp.ones((20, 100)),
    "counts_t2": jnp.ones((25, 100)),
}
result, state, meta = op.apply(data, {}, None)
result["transport_plan"].shape  # (20, 25)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | OTTrajectoryConfig | OT trajectory configuration. | required |
| rngs | Rngs \| None | Random number generators (for API consistency). | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply OT-based trajectory inference to two-timepoint expression data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing the two-timepoint expression data; the class example uses "counts_t1" (n_cells_t1, n_genes) and "counts_t2" (n_cells_t2, n_genes). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used (non-stochastic operator). | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata). transformed_data contains all original keys plus the operator outputs, including "transport_plan" (n_cells_t1, n_cells_t2). |
OTTrajectoryConfig¤
diffbio.operators.singlecell.ot_trajectory.OTTrajectoryConfig
dataclass
¤
OTTrajectoryConfig(
n_genes: int = 200,
sinkhorn_epsilon: float = 0.1,
sinkhorn_iters: int = 100,
growth_rate_regularization: float = 1.0,
interpolation_time: float = 0.5,
)
Bases: OperatorConfig
Configuration for OT-based trajectory inference.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_genes | int | Number of input genes per cell. |
| sinkhorn_epsilon | float | Entropy regularisation strength for the Sinkhorn solver. Larger values produce smoother transport plans. |
| sinkhorn_iters | int | Number of Sinkhorn iterations. |
| growth_rate_regularization | float | Scaling factor applied to raw row-sums before normalisation. Higher values amplify growth-rate variation. |
| interpolation_time | float | Fraction in (0, 1) at which to compute the interpolated cell distribution between t1 and t2. |
Usage Examples¤
Soft K-Means Clustering¤
from flax import nnx
from diffbio.operators.singlecell import SoftKMeansClustering, SoftClusteringConfig
config = SoftClusteringConfig(n_clusters=10, n_features=50)
clustering = SoftKMeansClustering(config, rngs=nnx.Rngs(42))
data = {"embeddings": embeddings}  # (n_cells, n_features)
result, _, _ = clustering.apply(data, {}, None)
assignments = result["cluster_assignments"]
Batch Correction¤
from flax import nnx
from diffbio.operators.singlecell import DifferentiableHarmony, BatchCorrectionConfig
config = BatchCorrectionConfig(n_clusters=50, n_features=50)
harmony = DifferentiableHarmony(config, rngs=nnx.Rngs(42))
data = {"embeddings": embeddings, "batch_ids": batch_labels}
result, _, _ = harmony.apply(data, {}, None)
corrected = result["corrected_embeddings"]
RNA Velocity¤
from flax import nnx
from diffbio.operators.singlecell import DifferentiableVelocity, VelocityConfig
config = VelocityConfig(n_genes=2000, hidden_dim=64)
velocity = DifferentiableVelocity(config, rngs=nnx.Rngs(42))
data = {"spliced": spliced, "unspliced": unspliced}
result, _, _ = velocity.apply(data, {}, None)
vel = result["velocity"]