Normalization Operators API
Differentiable normalization operators for count data, dimensionality reduction, and embeddings.
VAENormalizer
diffbio.operators.normalization.vae_normalizer.VAENormalizer
VAENormalizer(
config: VAENormalizerConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: CountReconstructionMixin, CountVAEBackboneMixin, EncoderDecoderOperator
Variational autoencoder for count normalization.
This operator learns a low-dimensional latent representation of single-cell gene expression data while accounting for technical factors like library size.
The model:
- Encoder: counts -> latent (mean, logvar)
- Reparameterization: z = mean + exp(0.5 * logvar) * epsilon
- Decoder: z -> gene expression rates (and optionally dispersion/dropout)
Supports two likelihood models:
- Poisson: Simple count model (default)
- ZINB: Zero-Inflated Negative Binomial for overdispersed data with excess zeros, as used in scVI
Inherits from EncoderDecoderOperator to get:
- reparameterize() for sampling with reparameterization trick
- kl_divergence() for KL from standard normal
- elbo_loss() for combining reconstruction and KL losses
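The three inherited pieces can be sketched in plain NumPy as follows. This is an illustration under stated assumptions, not the library's implementation: the function signatures, the `kl_weight` argument, and the Poisson reconstruction term (with the constant `log(counts!)` dropped) are choices made for this example only.

```python
import numpy as np

def reparameterize(mean, logvar, rng):
    """Sample z = mean + exp(0.5 * logvar) * epsilon, with epsilon ~ N(0, I)."""
    epsilon = rng.standard_normal(mean.shape)
    return mean + np.exp(0.5 * logvar) * epsilon

def kl_divergence(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)), summed over latent dimensions."""
    return -0.5 * np.sum(1.0 + logvar - mean**2 - np.exp(logvar))

def poisson_nll(counts, rate, eps=1e-8):
    """Poisson negative log-likelihood without the log(counts!) constant."""
    return np.sum(rate - counts * np.log(rate + eps))

def elbo_loss(counts, rate, mean, logvar, kl_weight=1.0):
    """Reconstruction loss plus (weighted) KL term; kl_weight is hypothetical."""
    return poisson_nll(counts, rate) + kl_weight * kl_divergence(mean, logvar)

rng = np.random.default_rng(0)
mean = np.array([0.5, -1.0])       # encoder mean for a 2-D latent space
logvar = np.array([0.0, -2.0])     # encoder log-variance
z = reparameterize(mean, logvar, rng)

counts = np.array([3.0, 0.0, 1.0])  # observed counts for 3 genes
rate = np.array([2.5, 0.2, 1.1])    # stand-in for decoder output rates
loss = elbo_loss(counts, rate, mean, logvar)
```

Because the noise `epsilon` is sampled independently of `mean` and `logvar`, gradients can flow through the sample `z` to the encoder outputs, which is the point of the reparameterization trick.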
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `VAENormalizerConfig` | VAENormalizerConfig with model parameters. | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. | `None` |
| `name` | `str \| None` | Optional operator name. | `None` |
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `VAENormalizerConfig` | VAE configuration. | required |
| `rngs` | `Rngs \| None` | Random number generators for initialization and sampling. | `None` |
| `name` | `str \| None` | Optional operator name. | `None` |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply VAE normalization to count data.
This method encodes the counts to latent space, samples a latent representation, and decodes to normalized expression.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `PyTree` | Dictionary containing `"counts"`: gene expression counts of shape `(n_genes,)`, and `"library_size"`: total counts for the cell. | required |
| `state` | `PyTree` | Element state (passed through unchanged). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| `random_params` | `Any` | Optional random parameters (not used). | `None` |
| `stats` | `dict[str, Any] \| None` | Not used. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
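As a minimal sketch of the input dictionary described in the Parameters table, the per-cell `library_size` can be derived from the counts themselves. The count values below are made up for illustration, and the scaled vector at the end is only a simple non-learned baseline for comparison, not what the operator returns.

```python
import numpy as np

# Hypothetical raw counts for a single cell across 5 genes.
counts = np.array([4.0, 0.0, 1.0, 0.0, 5.0])

data = {
    "counts": counts,              # shape (n_genes,)
    "library_size": counts.sum(),  # total counts for the cell
}

# Plain library-size scaling: each gene's share of the cell's total counts.
scaled = counts / data["library_size"]
```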
VAENormalizerConfig
diffbio.operators.normalization.vae_normalizer.VAENormalizerConfig
dataclass
VAENormalizerConfig(
latent_dim: int = 10,
hidden_dims: list[int] = field(default_factory=lambda: [128, 64]),
n_genes: int = 2000,
use_batch_correction: bool = False,
likelihood: Literal["poisson", "zinb"] = "poisson",
)
Bases: OperatorConfig
Configuration for VAENormalizer.
Attributes:
| Name | Type | Description |
|---|---|---|
| `latent_dim` | `int` | Dimension of latent space. |
| `hidden_dims` | `list[int]` | Hidden layer dimensions for encoder/decoder. |
| `n_genes` | `int` | Number of genes (input/output dimension). |
| `use_batch_correction` | `bool` | Whether to include batch effects. |
| `likelihood` | `Literal['poisson', 'zinb']` | Likelihood model for reconstruction loss. 'poisson' for standard Poisson NLL, 'zinb' for Zero-Inflated Negative Binomial. |
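To make the difference between the two `likelihood` options concrete, the sketch below evaluates both log-probabilities for a single count using only the standard library. The parameter names (`mu` for the mean, `theta` for the inverse dispersion, `pi` for the zero-inflation probability) follow a common ZINB parameterization and are assumptions of this example; the library's own parameterization may differ.

```python
import math

def poisson_log_prob(x, mu):
    """log P(x) for a Poisson distribution with mean mu."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)

def nb_log_prob(x, mu, theta):
    """log P(x) for a negative binomial with mean mu, inverse dispersion theta."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )

def zinb_log_prob(x, mu, theta, pi):
    """log P(x) for a zero-inflated NB: extra point mass pi at zero."""
    nb = nb_log_prob(x, mu, theta)
    if x == 0:
        # A zero is either a structural dropout or an NB zero.
        return math.log(pi + (1.0 - pi) * math.exp(nb))
    return math.log(1.0 - pi) + nb

# Total probability over a long range of counts should be close to 1.
total = sum(math.exp(zinb_log_prob(x, 3.0, 2.0, 0.3)) for x in range(200))
```

With `pi = 0` and a very large `theta`, the ZINB reduces to (approximately) the Poisson model, which is why Poisson is a reasonable default when the data show little overdispersion.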
DifferentiableUMAP
diffbio.operators.normalization.umap.DifferentiableUMAP
DifferentiableUMAP(
config: UMAPConfig, *, rngs: Rngs | None = None
)
Bases: OperatorModule
Differentiable UMAP for dimensionality reduction.
This operator implements a simplified differentiable version of UMAP that learns a low-dimensional embedding while preserving local structure.
The UMAP loss function is the binary cross-entropy between the high- and low-dimensional similarities, which is minimized:
L = -sum_edges [p_ij * log(q_ij) + (1 - p_ij) * log(1 - q_ij)]
where
- p_ij is the high-dimensional similarity (fuzzy set membership)
- q_ij is the low-dimensional similarity
This implementation uses a parametric approach with learnable curve parameters (a, b) for the low-dimensional similarity function.
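A minimal NumPy sketch of this objective is shown below, written as a quantity to minimize (the negated bracketed sum, i.e. a binary cross-entropy over pairs). It assumes the usual UMAP form `q_ij = 1 / (1 + a * d_ij^(2b))` for the low-dimensional similarity and keeps `a` and `b` fixed; in the operator they are described as learnable, and its exact similarity function is not documented here.

```python
import numpy as np

def low_dim_similarity(embedding, a=1.0, b=1.0):
    """q_ij = 1 / (1 + a * d_ij^(2b)) from pairwise embedding distances."""
    diff = embedding[:, None, :] - embedding[None, :, :]
    sq_dist = np.sum(diff**2, axis=-1)
    return 1.0 / (1.0 + a * sq_dist**b)

def umap_cross_entropy(p, embedding, a=1.0, b=1.0, eps=1e-8):
    """Cross-entropy between p and q over all pairs i != j."""
    q = np.clip(low_dim_similarity(embedding, a, b), eps, 1.0 - eps)
    off_diag = ~np.eye(p.shape[0], dtype=bool)
    terms = p * np.log(q) + (1.0 - p) * np.log(1.0 - q)
    return -np.sum(terms[off_diag])

# Two points that are strongly connected in the high-dimensional graph.
p = np.array([[0.0, 0.9],
              [0.9, 0.0]])
close = np.array([[0.0, 0.0], [0.5, 0.0]])  # embedding keeps them together
far = np.array([[0.0, 0.0], [3.0, 0.0]])    # embedding pulls them apart

loss_close = umap_cross_entropy(p, close)
loss_far = umap_cross_entropy(p, far)
```

The embedding that keeps the connected pair close yields the smaller loss, which is the behaviour gradient descent on this objective exploits.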
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `UMAPConfig` | Configuration for UMAP. | required |
| `rngs` | `Rngs \| None` | Random number generators for initialization. | `None` |
apply
apply(
data: dict[str, Any],
state: dict[str, Any],
metadata: dict | None,
random_params: dict | None = None,
stats: dict | None = None,
) -> tuple[dict, dict, dict | None]
Apply UMAP dimensionality reduction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Dictionary containing `'features'`: high-dimensional features of shape `(n_samples, n_features)`. | required |
| `state` | `dict[str, Any]` | Operator state dictionary. | required |
| `metadata` | `dict \| None` | Optional metadata dictionary. | required |
| `random_params` | `dict \| None` | Optional random parameters (unused). | `None` |
| `stats` | `dict \| None` | Optional statistics dictionary (unused). | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[dict, dict, dict \| None]` | Tuple of (output_data, state, metadata) where output_data contains: |
UMAPConfig
diffbio.operators.normalization.umap.UMAPConfig
dataclass
UMAPConfig(
n_components: int = 2,
n_neighbors: int = 15,
metric: str = "euclidean",
input_features: int = 64,
hidden_dim: int = 32,
)
Bases: OperatorConfig
Configuration for differentiable UMAP.
Attributes:
| Name | Type | Description |
|---|---|---|
| `n_components` | `int` | Number of dimensions in the embedding. |
| `n_neighbors` | `int` | Number of neighbors for local structure preservation. |
| `metric` | `str` | Distance metric ('euclidean' or 'cosine'). |
| `input_features` | `int` | Number of input features (required for initialization). |
| `hidden_dim` | `int` | Hidden dimension for projection network. |
| `stream_name` | `int` | Name of the data stream to process. |
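To illustrate what the `metric` and `n_neighbors` settings control, here is a small NumPy sketch of pairwise distances under both metrics and of nearest-neighbor selection. The helper names `pairwise_distances` and `nearest_neighbors` are hypothetical and exist only in this example; the operator's internal routines are not documented on this page.

```python
import numpy as np

def pairwise_distances(x, metric="euclidean", eps=1e-12):
    """Return the (n_samples, n_samples) distance matrix for the given metric."""
    if metric == "euclidean":
        diff = x[:, None, :] - x[None, :, :]
        return np.sqrt(np.sum(diff**2, axis=-1))
    if metric == "cosine":
        unit = x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)
        return 1.0 - unit @ unit.T  # cosine distance = 1 - cosine similarity
    raise ValueError(f"unsupported metric: {metric}")

def nearest_neighbors(dist, n_neighbors):
    """Indices of the n_neighbors closest points, excluding the point itself."""
    return np.argsort(dist, axis=1)[:, 1:n_neighbors + 1]

x = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [3.0, 4.0]])

d_euclidean = pairwise_distances(x, "euclidean")
d_cosine = pairwise_distances(x, "cosine")
neighbors = nearest_neighbors(d_euclidean, n_neighbors=1)
```

Cosine distance ignores vector length, so it can rank neighbors differently from Euclidean distance when samples differ mainly in overall magnitude.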
SequenceEmbedding
diffbio.operators.normalization.embedding.SequenceEmbedding
SequenceEmbedding(
config: SequenceEmbeddingConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: OperatorModule
Convolutional sequence embedding operator.
This operator converts one-hot encoded DNA sequences into dense embeddings using a stack of 1D convolutions followed by global average pooling.
The architecture:
1. Input: one-hot sequence (length, 4)
2. 1D convolutions with ReLU activation
3. Per-position features (length, embedding_dim)
4. Global average pooling -> fixed embedding (embedding_dim,)
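The architecture above can be sketched in NumPy with randomly initialized weights. This is illustrative only: the helper names are hypothetical, and the zero "same" padding is an assumption chosen so that the per-position features keep the input length, as the description implies.

```python
import numpy as np

def one_hot_dna(seq, alphabet="ACGT"):
    """Encode a DNA string as a (length, 4) one-hot array."""
    index = {base: i for i, base in enumerate(alphabet)}
    out = np.zeros((len(seq), len(alphabet)))
    for pos, base in enumerate(seq):
        out[pos, index[base]] = 1.0
    return out

def conv1d_same(x, kernel, bias):
    """1D convolution. x: (length, in_ch); kernel: (k, in_ch, out_ch)."""
    k = kernel.shape[0]
    pad_left = (k - 1) // 2
    padded = np.pad(x, ((pad_left, k - 1 - pad_left), (0, 0)))
    out = np.stack([
        np.tensordot(padded[i:i + k], kernel, axes=([0, 1], [0, 1]))
        for i in range(x.shape[0])
    ])
    return out + bias

def embed(seq, kernels, biases):
    """Conv stack with ReLU, then global average pooling over positions."""
    h = one_hot_dna(seq)
    for kernel, bias in zip(kernels, biases):
        h = np.maximum(conv1d_same(h, kernel, bias), 0.0)  # ReLU
    return h.mean(axis=0)

rng = np.random.default_rng(0)
embedding_dim, kernel_size, num_conv_layers = 8, 7, 3
in_channels = [4] + [embedding_dim] * (num_conv_layers - 1)
kernels = [0.1 * rng.standard_normal((kernel_size, c, embedding_dim))
           for c in in_channels]
biases = [np.zeros(embedding_dim) for _ in in_channels]

embedding = embed("ACGTACGTTGCA", kernels, biases)
```

Global average pooling is what makes the output size independent of sequence length: any input length yields an `(embedding_dim,)` vector.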
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `SequenceEmbeddingConfig` | SequenceEmbeddingConfig with model parameters. | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. | `None` |
| `name` | `str \| None` | Optional operator name. | `None` |
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `SequenceEmbeddingConfig` | Embedding configuration. | required |
| `rngs` | `Rngs \| None` | Random number generators for initialization. | `None` |
| `name` | `str \| None` | Optional operator name. | `None` |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply sequence embedding to sequence data.
This method extracts dense embeddings from one-hot encoded DNA sequences using convolutional feature extraction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `PyTree` | Dictionary containing `"sequence"`: one-hot encoded sequence of shape `(length, alphabet_size)`. | required |
| `state` | `PyTree` | Element state (passed through unchanged). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| `random_params` | `Any` | Not used (deterministic operator). | `None` |
| `stats` | `dict[str, Any] \| None` | Not used. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
SequenceEmbeddingConfig
diffbio.operators.normalization.embedding.SequenceEmbeddingConfig
dataclass
SequenceEmbeddingConfig(
embedding_dim: int = 64,
method: str = "conv",
kernel_size: int = 7,
num_conv_layers: int = 3,
)
Bases: OperatorConfig
Configuration for SequenceEmbedding.
Attributes:
| Name | Type | Description |
|---|---|---|
| `embedding_dim` | `int` | Dimension of output embedding. |
| `method` | `str` | Embedding method ("conv" for convolutional). |
| `kernel_size` | `int` | Convolution kernel size. |
| `num_conv_layers` | `int` | Number of convolutional layers. |
DifferentiablePHATE
diffbio.operators.normalization.phate.DifferentiablePHATE
DifferentiablePHATE(
config: PHATEConfig, *, rngs: Rngs | None = None
)
Bases: OperatorModule
Differentiable PHATE for dimensionality reduction.
Implements the full PHATE pipeline in a differentiable manner using JAX:
1. Pairwise distances via `compute_pairwise_distances` (DRY).
2. Alpha-decay affinity kernel: `K(i,j) = exp(-(d(i,j)/sigma_i)^decay)`, where `sigma_i` is the distance to the k-th neighbor.
3. Symmetrize via `symmetrize_graph` (DRY).
4. Row-normalize to Markov matrix `M`.
5. Diffusion `M^t` via eigendecomposition.
6. Potential distance: `-log(M^t + eps)` for `gamma=1` (log), or `(M^t)^((1-gamma)/2) / ((1-gamma)/2)` otherwise.
7. Classical MDS on the potential distance matrix: center, eigendecompose, take top `n_components` eigenvectors.
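The pipeline above can be sketched end to end in NumPy. This is a rough illustration, not the operator's code. It covers only the `gamma=1` (log) potential. It uses `np.linalg.matrix_power` in place of the eigendecomposition route for `M^t`. It also assumes, as in standard PHATE, that the potential distance matrix holds Euclidean distances between rows of the log potential. The function name `phate_sketch` and the `eps` value are hypothetical.

```python
import numpy as np

def phate_sketch(x, n_components=2, n_neighbors=5, decay=40.0,
                 diffusion_t=10, eps=1e-7):
    # 1. Pairwise Euclidean distances.
    diff = x[:, None, :] - x[None, :, :]
    dist = np.sqrt(np.sum(diff**2, axis=-1))
    # 2. Alpha-decay kernel; sigma_i is the distance to the k-th neighbor
    #    (index 0 of each sorted row is the point itself).
    sigma = np.sort(dist, axis=1)[:, n_neighbors][:, None]
    kernel = np.exp(-((dist / sigma) ** decay))
    # 3. Symmetrize the affinity matrix.
    kernel = 0.5 * (kernel + kernel.T)
    # 4. Row-normalize to a Markov matrix M.
    markov = kernel / kernel.sum(axis=1, keepdims=True)
    # 5. Diffuse for t steps (matrix power instead of eigendecomposition).
    diffused = np.linalg.matrix_power(markov, diffusion_t)
    # 6. Log potential, then distances between potential rows (assumption).
    potential = -np.log(diffused + eps)
    pdiff = potential[:, None, :] - potential[None, :, :]
    pot_dist = np.sqrt(np.sum(pdiff**2, axis=-1))
    # 7. Classical MDS: double-center squared distances, take top eigenpairs.
    n = x.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    gram = -0.5 * centering @ (pot_dist**2) @ centering
    eigvals, eigvecs = np.linalg.eigh(gram)
    top = np.argsort(eigvals)[::-1][:n_components]
    coords = eigvecs[:, top] * np.sqrt(np.maximum(eigvals[top], 0.0))
    return coords, markov

rng = np.random.default_rng(0)
features = rng.standard_normal((30, 5))  # 30 samples, 5 features
embedding, markov = phate_sketch(features)
```

Every step is built from smooth array operations, so the same sequence written with `jax.numpy` would admit gradients with respect to the input features.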
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `PHATEConfig` | Configuration for PHATE. | required |
| `rngs` | `Rngs \| None` | Random number generators for parameter initialization. | `None` |
apply
apply(
data: dict[str, Any],
state: dict[str, Any],
metadata: dict | None,
random_params: dict | None = None,
stats: dict | None = None,
) -> tuple[dict, dict, dict | None]
Apply PHATE dimensionality reduction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Dictionary containing: | required |
| `state` | `dict[str, Any]` | Operator state dictionary. | required |
| `metadata` | `dict \| None` | Optional metadata dictionary. | required |
| `random_params` | `dict \| None` | Optional random parameters (unused). | `None` |
| `stats` | `dict \| None` | Optional statistics dictionary (unused). | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[dict, dict, dict \| None]` | Tuple of |
PHATEConfig
diffbio.operators.normalization.phate.PHATEConfig
dataclass
PHATEConfig(
n_components: int = 2,
n_neighbors: int = 5,
decay: float = 40.0,
diffusion_t: int = 10,
gamma: float = 1.0,
input_features: int = 64,
hidden_dim: int = 32,
)
Bases: OperatorConfig
Configuration for differentiable PHATE.
Attributes:
| Name | Type | Description |
|---|---|---|
| `n_components` | `int` | Number of dimensions in the embedding. |
| `n_neighbors` | `int` | Number of nearest neighbors for local bandwidth. |
| `decay` | `float` | Exponent for the alpha-decaying kernel. Higher values produce sharper kernel tails (PHATE default 40). |
| `diffusion_t` | `int` | Power to which the diffusion operator is raised. Controls the level of diffusion smoothing. |
| `gamma` | `float` | Informational distance constant. |
| `input_features` | `int` | Number of input features (used for projection network). |
| `hidden_dim` | `int` | Hidden dimension for the projection network. |
Usage Examples
VAE Normalization
from flax import nnx
from diffbio.operators.normalization import VAENormalizer, VAENormalizerConfig

config = VAENormalizerConfig(n_genes=2000, latent_dim=10)
vae = VAENormalizer(config, rngs=nnx.Rngs(42))

# raw_counts: counts for one cell, shape (n_genes,), supplied by the caller.
# apply() also expects the cell's total counts under "library_size".
data = {"counts": raw_counts, "library_size": raw_counts.sum()}
result, _, _ = vae.apply(data, {}, None)
normalized = result["normalized"]
UMAP Dimensionality Reduction
from flax import nnx
from diffbio.operators.normalization import DifferentiableUMAP, UMAPConfig

# input_features must match the feature dimension of the input data.
config = UMAPConfig(n_components=2, n_neighbors=15, input_features=50)
umap = DifferentiableUMAP(config, rngs=nnx.Rngs(42))

# high_dim_data: shape (n_samples, n_features), supplied by the caller.
data = {"features": high_dim_data}
result, _, _ = umap.apply(data, {}, None)
embedding = result["embedding"]
Sequence Embedding
from flax import nnx
from diffbio.operators.normalization import SequenceEmbedding, SequenceEmbeddingConfig

# Only fields documented on SequenceEmbeddingConfig are passed here.
config = SequenceEmbeddingConfig(embedding_dim=64)
seq_embed = SequenceEmbedding(config, rngs=nnx.Rngs(42))

# sequence: one-hot encoded DNA, shape (length, 4), supplied by the caller.
data = {"sequence": sequence}
result, _, _ = seq_embed.apply(data, {}, None)
embeddings = result["embeddings"]