
Normalization Operators API

Differentiable normalization operators for count data, dimensionality reduction, and embeddings.

VAENormalizer

diffbio.operators.normalization.vae_normalizer.VAENormalizer

VAENormalizer(
    config: VAENormalizerConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: CountReconstructionMixin, CountVAEBackboneMixin, EncoderDecoderOperator

Variational autoencoder for count normalization.

This operator learns a low-dimensional latent representation of single-cell gene expression data while accounting for technical factors like library size.

The model:

  • Encoder: counts -> latent (mean, logvar)
  • Reparameterization: z = mean + exp(0.5 * logvar) * epsilon
  • Decoder: z -> gene expression rates (and optionally dispersion/dropout)

Two likelihood models are supported:

  • Poisson: simple count model (default)
  • ZINB: Zero-Inflated Negative Binomial for overdispersed data with excess zeros, as used in scVI
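For reference, the two likelihoods can be written as per-gene log-probabilities. This is a minimal plain-JAX sketch, not VAENormalizer's code: the function names are hypothetical, the ZINB parameterization (NB mean mu, inverse dispersion theta, dropout logit) is assumed to follow scVI's convention, and how VAENormalizer combines "library_size" with the decoded log rates is not specified here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def poisson_log_prob(x, log_rate):
    # log Poisson(x | rate) with rate = exp(log_rate)
    return x * log_rate - jnp.exp(log_rate) - gammaln(x + 1.0)


def zinb_log_prob(x, mu, theta, pi_logit, eps=1e-8):
    # mu: NB mean, theta: inverse dispersion, pi_logit: logit of the dropout probability
    log_theta_mu = jnp.log(theta + mu + eps)
    # log NB(0) = theta * log(theta / (theta + mu))
    log_nb_zero = theta * (jnp.log(theta + eps) - log_theta_mu)
    nb_log_prob = (
        gammaln(x + theta) - gammaln(theta) - gammaln(x + 1.0)
        + log_nb_zero
        + x * (jnp.log(mu + eps) - log_theta_mu)
    )
    log_pi = jax.nn.log_sigmoid(pi_logit)             # log(pi)
    log_one_minus_pi = jax.nn.log_sigmoid(-pi_logit)  # log(1 - pi)
    # A zero comes either from dropout or from the NB component itself
    zero_case = jnp.logaddexp(log_pi, log_one_minus_pi + log_nb_zero)
    nonzero_case = log_one_minus_pi + nb_log_prob
    return jnp.where(x < 0.5, zero_case, nonzero_case)
```

The reconstruction loss for one cell would then be the negative sum over genes, for example `-zinb_log_prob(counts, mu, theta, pi_logit).sum()`.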

Inherits from EncoderDecoderOperator, which provides:

  • reparameterize() for sampling with the reparameterization trick
  • kl_divergence() for the KL divergence from a standard normal prior
  • elbo_loss() for combining reconstruction and KL losses
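The three inherited helpers correspond to standard VAE formulas. The functions below are a hypothetical plain-JAX sketch of those formulas, not EncoderDecoderOperator's actual implementation; the kl_weight argument in particular is an illustrative assumption.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, logvar):
    # z = mean + exp(0.5 * logvar) * epsilon, with epsilon ~ N(0, I)
    epsilon = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * epsilon


def kl_divergence(mean, logvar):
    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over latent dimensions
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)


def elbo_loss(reconstruction_nll, mean, logvar, kl_weight=1.0):
    # Negative ELBO: reconstruction negative log-likelihood plus the weighted KL term
    return reconstruction_nll + kl_weight * kl_divergence(mean, logvar)
```

Because epsilon is the only sampled quantity, gradients of the loss reach mean and logvar through z.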

Parameters:

  • config (VAENormalizerConfig, required): VAENormalizerConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators, used for initialization and sampling.
  • name (str | None, default None): Optional operator name.

Example
config = VAENormalizerConfig(n_genes=2000, latent_dim=10)
normalizer = VAENormalizer(config, rngs=nnx.Rngs(42))
data = {"counts": counts, "library_size": lib_size}
result, state, meta = normalizer.apply(data, {}, None)


apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply VAE normalization to count data.

This method encodes the counts into the latent space, samples a latent representation, and decodes it into normalized expression.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "counts": Gene expression counts (n_genes,)
      - "library_size": Total counts for the cell
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Optional random parameters (not used).
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where transformed_data contains:

  • "counts": Original counts
  • "normalized": Normalized expression
  • "latent_z": Sampled latent representation
  • "latent_mean": Mean of the latent distribution
  • "latent_logvar": Log variance of the latent distribution
  • "log_rate": Decoded log rates

state and metadata are passed through unchanged.

VAENormalizerConfig

diffbio.operators.normalization.vae_normalizer.VAENormalizerConfig dataclass

VAENormalizerConfig(
    latent_dim: int = 10,
    hidden_dims: list[int] = [128, 64],  # default_factory: each instance gets its own list
    n_genes: int = 2000,
    use_batch_correction: bool = False,
    likelihood: Literal["poisson", "zinb"] = "poisson",
)

Bases: OperatorConfig

Configuration for VAENormalizer.

Attributes:

  • latent_dim (int): Dimension of the latent space.
  • hidden_dims (list[int]): Hidden layer dimensions for the encoder/decoder.
  • n_genes (int): Number of genes (input/output dimension).
  • use_batch_correction (bool): Whether to include batch effects.
  • likelihood (Literal['poisson', 'zinb']): Likelihood model for the reconstruction loss: 'poisson' for the standard Poisson NLL, 'zinb' for the Zero-Inflated Negative Binomial.

DifferentiableUMAP

diffbio.operators.normalization.umap.DifferentiableUMAP

DifferentiableUMAP(
    config: UMAPConfig, *, rngs: Rngs | None = None
)

Bases: OperatorModule

Differentiable UMAP for dimensionality reduction.

This operator implements a simplified differentiable version of UMAP that learns a low-dimensional embedding while preserving local structure.

The UMAP loss is the cross-entropy between the high- and low-dimensional fuzzy graphs:

L = -sum_edges [p_ij * log(q_ij) + (1 - p_ij) * log(1 - q_ij)]

where
  • p_ij is the high-dimensional similarity (fuzzy set membership)
  • q_ij is the low-dimensional similarity

This implementation uses a parametric approach with learnable curve parameters (a, b) for the low-dimensional similarity function.
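The loss and the learnable curve can be sketched in plain JAX. Assumptions: the low-dimensional similarity uses the standard UMAP curve q_ij = 1 / (1 + a * d_ij^(2b)), every pair i != j is treated as an edge (a dense loss rather than UMAP's edge sampling with negative sampling), and the function names are hypothetical; DifferentiableUMAP's exact formulation may differ.

```python
import jax
import jax.numpy as jnp


def low_dim_similarity(y, a, b):
    # q_ij = 1 / (1 + a * d_ij^(2b)) for an embedding y of shape (n, n_components)
    d2 = jnp.sum((y[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    # d^(2b) == (d^2)^b; the clamp keeps the gradient finite on the zero diagonal when b < 1
    return 1.0 / (1.0 + a * jnp.maximum(d2, 1e-12) ** b)


def umap_cross_entropy(p, q, eps=1e-6):
    # L = -sum_ij [p_ij log q_ij + (1 - p_ij) log(1 - q_ij)] over all pairs i != j
    off_diagonal = 1.0 - jnp.eye(p.shape[0])
    ce = -(p * jnp.log(q + eps) + (1.0 - p) * jnp.log(1.0 - q + eps))
    return jnp.sum(ce * off_diagonal)
```

Since a and b are ordinary inputs, jax.grad can differentiate the loss with respect to them as well as the embedding, which is what a learnable-curve variant relies on.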

Example
config = UMAPConfig(
    n_components=2,
    n_neighbors=15,
)
umap = DifferentiableUMAP(config, rngs=rngs)

data = {"features": high_dim_data}  # (n_samples, n_features)
result, state, metadata = umap.apply(data, {}, None)
embedding = result["embedding"]  # (n_samples, n_components)

Parameters:

  • config (UMAPConfig, required): Configuration for UMAP.
  • rngs (Rngs | None, default None): Random number generators for initialization.

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict | None,
    random_params: dict | None = None,
    stats: dict | None = None,
) -> tuple[dict, dict, dict | None]

Apply UMAP dimensionality reduction.

Parameters:

  • data (dict[str, Any], required): Dictionary containing:
      - 'features': High-dimensional features of shape (n_samples, n_features)
  • state (dict[str, Any], required): Operator state dictionary.
  • metadata (dict | None, required): Optional metadata dictionary.
  • random_params (dict | None, default None): Optional random parameters (unused).
  • stats (dict | None, default None): Optional statistics dictionary (unused).

Returns:

tuple[dict, dict, dict | None]: Tuple of (output_data, state, metadata), where output_data contains:

  • 'features': Original high-dimensional features
  • 'embedding': Low-dimensional embedding
  • 'high_dim_similarities': Fuzzy set memberships (p_ij)
  • 'low_dim_similarities': Embedding similarities (q_ij)
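The 'high_dim_similarities' output holds UMAP's fuzzy set memberships. In standard UMAP each point i gets directed weights p_j|i = exp(-max(0, d_ij - rho_i) / sigma_i) over its k nearest neighbors, with rho_i the nearest-neighbor distance and sigma_i chosen so the weights sum to log2(k); the directed graph is then combined by fuzzy union. The sketch below follows that recipe with dense distances and a fixed-iteration bisection for sigma_i; the function name and n_iter are hypothetical, and DifferentiableUMAP's simplified construction may differ. Gradients flow through the distances, but not through the bisection that picks sigma_i.

```python
import jax
import jax.numpy as jnp


def fuzzy_memberships(x, n_neighbors=15, n_iter=64):
    n = x.shape[0]
    d = jnp.sqrt(jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1))
    # k nearest neighbours of each point (sorted column 0 is the point itself)
    order = jnp.argsort(d, axis=1)[:, 1 : n_neighbors + 1]
    knn_d = jnp.take_along_axis(d, order, axis=1)          # (n, k)
    rho = knn_d[:, 0]                                       # distance to nearest neighbour
    target = jnp.log2(n_neighbors)

    def memberships(sigma):
        # p_j|i = exp(-max(0, d_ij - rho_i) / sigma_i)
        return jnp.exp(-jnp.maximum(knn_d - rho[:, None], 0.0) / sigma[:, None])

    def bisect(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # The membership sum grows with sigma, so a too-large sum means sigma is too large
        too_big = memberships(mid).sum(axis=1) > target
        return jnp.where(too_big, lo, mid), jnp.where(too_big, mid, hi)

    init = (jnp.full(n, 1e-6, dtype=jnp.float32), jnp.full(n, 1e3, dtype=jnp.float32))
    lo, hi = jax.lax.fori_loop(0, n_iter, bisect, init)
    sigma = 0.5 * (lo + hi)
    # Scatter the directed k-NN weights into a dense (n, n) matrix
    rows = jnp.arange(n)[:, None]
    p_dir = jnp.zeros((n, n)).at[rows, order].set(memberships(sigma))
    # Fuzzy union: p_ij = p_j|i + p_i|j - p_j|i * p_i|j
    return p_dir + p_dir.T - p_dir * p_dir.T
```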

UMAPConfig

diffbio.operators.normalization.umap.UMAPConfig dataclass

UMAPConfig(
    n_components: int = 2,
    n_neighbors: int = 15,
    metric: str = "euclidean",
    input_features: int = 64,
    hidden_dim: int = 32,
)

Bases: OperatorConfig

Configuration for differentiable UMAP.

Attributes:

  • n_components (int): Number of dimensions in the embedding.
  • n_neighbors (int): Number of neighbors for local structure preservation.
  • metric (str): Distance metric ('euclidean' or 'cosine').
  • input_features (int): Number of input features (required for initialization).
  • hidden_dim (int): Hidden dimension for the projection network.
  • stream_name (str): Name of the data stream to process.
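The metric attribute selects between two distances. A minimal sketch of both as dense all-pairs computations; the function name and the eps clamp are illustrative assumptions, not DifferentiableUMAP's code.

```python
import jax.numpy as jnp


def pairwise_distances(x, metric="euclidean", eps=1e-12):
    """Dense (n, n) distance matrix for x of shape (n, n_features)."""
    if metric == "euclidean":
        sq = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
        # Clamp before sqrt so the gradient stays finite on the zero diagonal
        return jnp.sqrt(jnp.maximum(sq, eps))
    if metric == "cosine":
        norms = jnp.linalg.norm(x, axis=-1, keepdims=True)
        unit = x / jnp.maximum(norms, eps)
        return 1.0 - unit @ unit.T
    raise ValueError(f"unsupported metric: {metric!r}")
```

Cosine distance ignores vector magnitude, so two cells with proportional feature vectors are at distance 0 under 'cosine' but not under 'euclidean'.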

SequenceEmbedding

diffbio.operators.normalization.embedding.SequenceEmbedding

SequenceEmbedding(
    config: SequenceEmbeddingConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Convolutional sequence embedding operator.

This operator converts one-hot encoded DNA sequences into dense embeddings using a stack of 1D convolutions followed by global average pooling.

The architecture:

  1. Input: one-hot sequence (length, 4)
  2. 1D convolutions with ReLU activation
  3. Per-position features (length, embedding_dim)
  4. Global average pooling -> fixed embedding (embedding_dim,)
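The four steps can be sketched in plain JAX. This is an illustrative stand-in rather than SequenceEmbedding's implementation: the helper names, "SAME" padding, the weight initialization, and applying ReLU after every layer are assumptions, while embedding_dim, kernel_size, and num_conv_layers mirror SequenceEmbeddingConfig.

```python
import jax
import jax.numpy as jnp


def init_conv_params(key, embedding_dim=64, kernel_size=7, num_conv_layers=3, alphabet_size=4):
    params = []
    in_channels = alphabet_size
    for layer_key in jax.random.split(key, num_conv_layers):
        # Kernel layout (width, in_channels, out_channels), i.e. "WIO"
        scale = 1.0 / jnp.sqrt(kernel_size * in_channels)
        w = scale * jax.random.normal(layer_key, (kernel_size, in_channels, embedding_dim))
        params.append((w, jnp.zeros(embedding_dim)))
        in_channels = embedding_dim
    return params


def embed_sequence(params, one_hot_seq):
    # one_hot_seq: (length, 4); add a batch axis for the convolution
    h = one_hot_seq[None]
    for w, b in params:
        h = jax.lax.conv_general_dilated(
            h, w, window_strides=(1,), padding="SAME",
            dimension_numbers=("NWC", "WIO", "NWC"),
        )
        h = jax.nn.relu(h + b)
    position_embeddings = h[0]                      # (length, embedding_dim)
    embedding = position_embeddings.mean(axis=0)    # global average pooling
    return embedding, position_embeddings


params = init_conv_params(jax.random.PRNGKey(0))
tokens = jax.random.randint(jax.random.PRNGKey(1), (100,), 0, 4)
embedding, position_embeddings = embed_sequence(params, jax.nn.one_hot(tokens, 4))
```

"SAME" padding keeps the sequence length unchanged, so position_embeddings lines up base-for-base with the input.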

Parameters:

  • config (SequenceEmbeddingConfig, required): SequenceEmbeddingConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators, used for initialization.
  • name (str | None, default None): Optional operator name.

Example
config = SequenceEmbeddingConfig(embedding_dim=64)
embedder = SequenceEmbedding(config, rngs=nnx.Rngs(42))
data = {"sequence": encoded_seq}
result, state, meta = embedder.apply(data, {}, None)


apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply sequence embedding to sequence data.

This method extracts dense embeddings from one-hot encoded DNA sequences using convolutional feature extraction.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "sequence": One-hot encoded sequence (length, alphabet_size)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used (deterministic operator).
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where transformed_data contains:

  • "sequence": Original sequence
  • "embedding": Global sequence embedding (embedding_dim,)
  • "position_embeddings": Per-position features (length, embedding_dim)

state and metadata are passed through unchanged.

SequenceEmbeddingConfig

diffbio.operators.normalization.embedding.SequenceEmbeddingConfig dataclass

SequenceEmbeddingConfig(
    embedding_dim: int = 64,
    method: str = "conv",
    kernel_size: int = 7,
    num_conv_layers: int = 3,
)

Bases: OperatorConfig

Configuration for SequenceEmbedding.

Attributes:

Name Type Description
embedding_dim int

Dimension of output embedding.

method str

Embedding method ("conv" for convolutional).

kernel_size int

Convolution kernel size.

num_conv_layers int

Number of convolutional layers.

DifferentiablePHATE

diffbio.operators.normalization.phate.DifferentiablePHATE

DifferentiablePHATE(
    config: PHATEConfig, *, rngs: Rngs | None = None
)

Bases: OperatorModule

Differentiable PHATE for dimensionality reduction.

Implements the full PHATE pipeline in a differentiable manner using JAX:

  1. Pairwise distances via the shared compute_pairwise_distances helper.
  2. Alpha-decay affinity kernel: K(i,j) = exp(-(d(i,j)/sigma_i)^decay), where sigma_i is the distance to the k-th nearest neighbor.
  3. Symmetrize the kernel via the shared symmetrize_graph helper.
  4. Row-normalize to a Markov matrix M.
  5. Diffuse by computing M^t via eigendecomposition.
  6. Potential distance: -log(M^t + eps) for gamma=1 (log potential), or (M^t)^((1-gamma)/2) / ((1-gamma)/2) otherwise.
  7. Classical MDS on the potential distance matrix: double-center, eigendecompose, and keep the top n_components eigenvectors.

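The seven steps can be condensed into a plain-JAX sketch for small inputs. It assumes Euclidean distances and gamma = 1, and, following the original PHATE formulation, takes Euclidean distances between rows of the potential as the potential distance matrix; it does not use DifferentiablePHATE's own helpers, and the function names are hypothetical. Diffusion goes through the symmetric conjugate A = D^(-1/2) K D^(-1/2), which has the same eigenvalues as M and allows the stable symmetric solver eigh.

```python
import jax
import jax.numpy as jnp


def pairwise_distances(x):
    # Dense Euclidean distances, shape (n, n); the diagonal is exactly 0
    diff = x[:, None, :] - x[None, :, :]
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))


def phate_sketch(x, n_components=2, n_neighbors=5, decay=40.0, t=10, eps=1e-7):
    n = x.shape[0]
    d = pairwise_distances(x)                                   # step 1
    # Step 2: sigma_i = distance to the k-th neighbour (sorted column 0 is the point itself)
    sigma = jnp.sort(d, axis=1)[:, n_neighbors]
    k = jnp.exp(-((d / sigma[:, None]) ** decay))
    k = 0.5 * (k + k.T)                                         # step 3
    # Steps 4-5: M = D^-1 K, and M^t = D^-1/2 V diag(lam^t) V^T D^1/2 with (lam, V) = eigh(A)
    deg = k.sum(axis=1)
    a = k / jnp.sqrt(deg[:, None] * deg[None, :])
    lam, v = jnp.linalg.eigh(a)
    a_t = (v * lam**t) @ v.T
    m_t = a_t * jnp.sqrt(deg[None, :] / deg[:, None])
    # Step 6 (gamma = 1): log potential; clip round-off negatives before the log
    u = -jnp.log(jnp.clip(m_t, 0.0) + eps)
    pot_dist = pairwise_distances(u)
    # Step 7: classical MDS (double-centre squared distances, keep the top eigenvectors)
    j = jnp.eye(n) - 1.0 / n
    b = -0.5 * j @ (pot_dist**2) @ j
    evals, evecs = jnp.linalg.eigh(b)
    top = jnp.argsort(evals)[::-1][:n_components]
    embedding = evecs[:, top] * jnp.sqrt(jnp.clip(evals[top], 0.0))
    return embedding, pot_dist, m_t


x = jax.random.normal(jax.random.PRNGKey(0), (30, 8))
embedding, pot_dist, m_t = phate_sketch(x)   # embedding: (30, 2)
```

Every step is built from differentiable JAX primitives (sort, exp, eigh, log), which is what makes an end-to-end differentiable PHATE possible, although eigh gradients become unstable when eigenvalues are nearly degenerate.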
Example
config = PHATEConfig(n_components=2, n_neighbors=5, diffusion_t=10)
phate = DifferentiablePHATE(config, rngs=rngs)

data = {"features": high_dim_data}  # (n_samples, n_features)
result, state, metadata = phate.apply(data, {}, None)
embedding = result["embedding"]  # (n_samples, n_components)

Parameters:

  • config (PHATEConfig, required): Configuration for PHATE.
  • rngs (Rngs | None, default None): Random number generators for parameter initialization.

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict | None,
    random_params: dict | None = None,
    stats: dict | None = None,
) -> tuple[dict, dict, dict | None]

Apply PHATE dimensionality reduction.

Parameters:

  • data (dict[str, Any], required): Dictionary containing:
      - "features": High-dimensional features (n_samples, n_features)
  • state (dict[str, Any], required): Operator state dictionary.
  • metadata (dict | None, required): Optional metadata dictionary.
  • random_params (dict | None, default None): Optional random parameters (unused).
  • stats (dict | None, default None): Optional statistics dictionary (unused).

Returns:

tuple[dict, dict, dict | None]: Tuple of (output_data, state, metadata), where output_data contains:

  • "features": Original high-dimensional features
  • "embedding": Low-dimensional PHATE embedding (n_samples, n_components)
  • "potential_distances": Symmetric potential distance matrix (n_samples, n_samples)
  • "diffusion_operator": Row-stochastic diffusion matrix M^t (n_samples, n_samples)

PHATEConfig

diffbio.operators.normalization.phate.PHATEConfig dataclass

PHATEConfig(
    n_components: int = 2,
    n_neighbors: int = 5,
    decay: float = 40.0,
    diffusion_t: int = 10,
    gamma: float = 1.0,
    input_features: int = 64,
    hidden_dim: int = 32,
)

Bases: OperatorConfig

Configuration for differentiable PHATE.

Attributes:

  • n_components (int): Number of dimensions in the embedding.
  • n_neighbors (int): Number of nearest neighbors used for the local bandwidth.
  • decay (float): Exponent of the alpha-decaying kernel. Higher values make the kernel fall off more sharply beyond sigma_i (PHATE default: 40).
  • diffusion_t (int): Power to which the diffusion operator is raised; controls the amount of diffusion smoothing.
  • gamma (float): Informational distance constant. gamma=1 gives the log potential, gamma=0 gives the square-root potential.
  • input_features (int): Number of input features (used for the projection network).
  • hidden_dim (int): Hidden dimension of the projection network.
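The effect of gamma can be shown with a small function implementing the potential formula from step 6 of the PHATE pipeline. A sketch under stated assumptions: the name phate_potential and the clipping are illustrative, and because gamma drives a Python branch it must be a static (non-traced) value, as a config field would be.

```python
import jax.numpy as jnp


def phate_potential(m_t, gamma=1.0, eps=1e-7):
    # Clip tiny negative round-off (e.g. from an eigendecomposition) before log/power
    m_t = jnp.clip(m_t, 0.0)
    if gamma == 1.0:
        return -jnp.log(m_t + eps)     # log potential
    c = (1.0 - gamma) / 2.0
    return m_t ** c / c                # gamma = 0 gives 2 * sqrt(M^t)
```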

Usage Examples

VAE Normalization

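import jax
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.normalization import VAENormalizer, VAENormalizerConfig

config = VAENormalizerConfig(n_genes=2000, latent_dim=10)
vae = VAENormalizer(config, rngs=nnx.Rngs(42))

# apply() works on one cell: counts of shape (n_genes,) plus that cell's library size.
# Synthetic Poisson counts stand in for real data here.
counts = jax.random.poisson(jax.random.PRNGKey(0), 1.0, (2000,)).astype(jnp.float32)
data = {"counts": counts, "library_size": counts.sum()}
result, _, _ = vae.apply(data, {}, None)
normalized = result["normalized"]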

UMAP Dimensionality Reduction

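import jax
from flax import nnx
from diffbio.operators.normalization import DifferentiableUMAP, UMAPConfig

config = UMAPConfig(n_components=2, n_neighbors=15, input_features=50)
umap = DifferentiableUMAP(config, rngs=nnx.Rngs(42))

# Set input_features to the feature dimension of the data (50 here)
high_dim_data = jax.random.normal(jax.random.PRNGKey(0), (200, 50))  # (n_samples, n_features)
result, _, _ = umap.apply({"features": high_dim_data}, {}, None)
embedding = result["embedding"]  # (n_samples, n_components)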

Sequence Embedding

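import jax
from flax import nnx
from diffbio.operators.normalization import SequenceEmbedding, SequenceEmbeddingConfig

config = SequenceEmbeddingConfig(embedding_dim=64, kernel_size=7, num_conv_layers=3)
seq_embed = SequenceEmbedding(config, rngs=nnx.Rngs(42))

# One-hot encoded DNA sequence of shape (length, 4)
tokens = jax.random.randint(jax.random.PRNGKey(0), (100,), 0, 4)
data = {"sequence": jax.nn.one_hot(tokens, 4)}
result, _, _ = seq_embed.apply(data, {}, None)
embedding = result["embedding"]  # (embedding_dim,)
position_embeddings = result["position_embeddings"]  # (length, embedding_dim)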