Statistical Operators API

Differentiable statistical operators for probabilistic modeling including HMMs, GLMs, and EM algorithms.

DifferentiableHMM

diffbio.operators.statistical.hmm.DifferentiableHMM

DifferentiableHMM(
    config: HMMConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: HMMOperator

Differentiable Hidden Markov Model.

This operator implements the HMM forward algorithm with differentiable operations, enabling gradient-based learning of transition and emission parameters.

The forward algorithm computes P(observations | model) using dynamic programming with logsumexp for numerical stability:

alpha[t, j] = sum_i(alpha[t-1, i] * A[i,j]) * B[j, o_t]

In log space: log_alpha[t, j] = logsumexp_i(log_alpha[t-1, i] + log_A[i,j]) + log_B[j, o_t]
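The log-space recursion above can be sketched in plain JAX with `jax.lax.scan`. This is a standalone illustration, not diffbio's implementation: the function name `forward_log_likelihood`, the argument layout (`log_A[i, j]` = log P(next state j | state i)), and the softmax-parameterized toy model are assumptions. A brute-force sum over all hidden paths checks the result on a tiny model, and `jax.grad` shows that the likelihood is differentiable with respect to unconstrained transition logits.

```python
import itertools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def forward_log_likelihood(log_pi, log_A, log_B, obs):
    """Log-space forward algorithm returning log P(obs | model).

    log_pi: (S,) log initial distribution
    log_A:  (S, S) log transitions, log_A[i, j] = log P(state j | state i)
    log_B:  (S, E) log emissions, log_B[j, o] = log P(symbol o | state j)
    obs:    (T,) integer-encoded observations
    """
    log_alpha0 = log_pi + log_B[:, obs[0]]

    def step(log_alpha, o_t):
        # log_alpha[t, j] = logsumexp_i(log_alpha[t-1, i] + log_A[i, j]) + log_B[j, o_t]
        new_alpha = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_B[:, o_t]
        return new_alpha, None

    log_alpha_T, _ = jax.lax.scan(step, log_alpha0, obs[1:])
    return logsumexp(log_alpha_T)


# Hypothetical toy model: 2 hidden states, 3 emission symbols
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
S, E = 2, 3
log_pi = jax.nn.log_softmax(jax.random.normal(k1, (S,)))
log_A = jax.nn.log_softmax(jax.random.normal(k2, (S, S)), axis=1)
log_B = jax.nn.log_softmax(jax.random.normal(k3, (S, E)), axis=1)
obs = jnp.array([0, 2, 1, 1])

ll = forward_log_likelihood(log_pi, log_A, log_B, obs)

# Brute-force check: log of the sum over all S**T hidden paths
path_log_probs = []
for path in itertools.product(range(S), repeat=len(obs)):
    lp = log_pi[path[0]] + log_B[path[0], obs[0]]
    for t in range(1, len(obs)):
        lp = lp + log_A[path[t - 1], path[t]] + log_B[path[t], obs[t]]
    path_log_probs.append(lp)
ll_brute = logsumexp(jnp.stack(path_log_probs))


# Gradients flow back to unconstrained transition logits
def neg_ll(A_logits):
    return -forward_log_likelihood(log_pi, jax.nn.log_softmax(A_logits, axis=1), log_B, obs)


grads = jax.grad(neg_ll)(jnp.zeros((S, S)))
```

Because each row of the transition matrix comes from a softmax, adding a constant to a row of logits leaves the likelihood unchanged, so each row of `grads` sums to zero.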

Inherits from HMMOperator to get:

  • forward_pass() for likelihood computation
  • forward_backward_posteriors() for posterior computation
  • get_log_transition_matrix(), get_log_emission_matrix(), get_log_initial_distribution() for parameter access

Parameters:

  • config (HMMConfig, required): HMMConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

Example

```python
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.statistical import DifferentiableHMM, HMMConfig

config = HMMConfig(num_states=3, num_emissions=4)
hmm = DifferentiableHMM(config, rngs=nnx.Rngs(42))
data = {"observations": jnp.array([0, 1, 2, 3])}
result, state, meta = hmm.apply(data, {}, None)
```


apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply HMM to observation sequence.

This method computes the log-likelihood and state posteriors for a given observation sequence.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "observations": Integer-encoded observations, shape (seq_len,)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used (deterministic operator).
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

A tuple[PyTree, PyTree, dict[str, Any] | None] of (transformed_data, state, metadata):

  • transformed_data contains:
      - "observations": Original observations
      - "log_likelihood": Log probability of the sequence
      - "state_posteriors": P(state | observations) at each position
  • state is passed through unchanged
  • metadata is passed through unchanged
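The "state_posteriors" output is P(z_t = j | observations) for every position t and state j. A standard way to compute it is the log-space forward-backward algorithm, sketched below in plain JAX. This is an illustrative stand-in, not diffbio's forward_backward_posteriors(); the function name `state_posteriors` and the argument layout (same as the forward sketch conventions: `log_A[i, j]` = log P(j | i)) are assumptions. A brute-force enumeration over hidden paths confirms the values on a tiny model.

```python
import itertools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def state_posteriors(log_pi, log_A, log_B, obs):
    """Return (T, S) posteriors P(z_t = j | obs) via log-space forward-backward."""
    def fwd(log_alpha, o_t):
        new = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_B[:, o_t]
        return new, new

    log_alpha0 = log_pi + log_B[:, obs[0]]
    _, rest = jax.lax.scan(fwd, log_alpha0, obs[1:])
    log_alpha = jnp.concatenate([log_alpha0[None, :], rest], axis=0)  # (T, S)

    def bwd(log_beta, o_next):
        # log_beta[t, i] = logsumexp_j(log_A[i, j] + log_B[j, o_{t+1}] + log_beta[t+1, j])
        new = logsumexp(log_A + log_B[:, o_next][None, :] + log_beta[None, :], axis=1)
        return new, new

    log_beta_last = jnp.zeros_like(log_pi)  # log(1) at the final position
    _, rest_b = jax.lax.scan(bwd, log_beta_last, obs[1:], reverse=True)
    log_beta = jnp.concatenate([rest_b, log_beta_last[None, :]], axis=0)  # (T, S)

    # Normalize alpha * beta per position to get a distribution over states
    log_post = log_alpha + log_beta
    return jnp.exp(log_post - logsumexp(log_post, axis=1, keepdims=True))


# Hypothetical toy model: 2 hidden states, 3 emission symbols
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
S, E = 2, 3
log_pi = jax.nn.log_softmax(jax.random.normal(k1, (S,)))
log_A = jax.nn.log_softmax(jax.random.normal(k2, (S, S)), axis=1)
log_B = jax.nn.log_softmax(jax.random.normal(k3, (S, E)), axis=1)
obs = jnp.array([0, 2, 1, 1])
post = state_posteriors(log_pi, log_A, log_B, obs)

# Brute-force check: accumulate P(path, obs) for each (position, state) pair
joint = jnp.zeros((len(obs), S))
for path in itertools.product(range(S), repeat=len(obs)):
    lp = log_pi[path[0]] + log_B[path[0], obs[0]]
    for t in range(1, len(obs)):
        lp = lp + log_A[path[t - 1], path[t]] + log_B[path[t], obs[t]]
    for t, s in enumerate(path):
        joint = joint.at[t, s].add(jnp.exp(lp))
brute_post = joint / joint.sum(axis=1, keepdims=True)
```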

HMMConfig

diffbio.operators.statistical.hmm.HMMConfig dataclass

HMMConfig(
    num_states: int = 3,
    num_emissions: int = 4,
    temperature: float = 1.0,
    learnable_transitions: bool = True,
    learnable_emissions: bool = True,
)

Bases: OperatorConfig

Configuration for DifferentiableHMM.

Attributes:

  • num_states (int): Number of hidden states.
  • num_emissions (int): Number of possible emissions (e.g., 4 for DNA).
  • temperature (float): Temperature for softmax operations.
  • learnable_transitions (bool): Whether transition probabilities are learnable.
  • learnable_emissions (bool): Whether emission probabilities are learnable.
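The config does not say exactly where temperature enters the model. One common pattern, assumed here and not confirmed by the source, is to store unconstrained logits and map them to row-stochastic matrices with softmax(logits / temperature). The hypothetical helper `to_stochastic` shows the effect: every row still sums to 1, low temperatures sharpen rows toward one-hot, and high temperatures flatten them toward uniform.

```python
import jax.numpy as jnp
from jax.nn import softmax


def to_stochastic(logits, temperature=1.0):
    # Hypothetical helper: unconstrained logits -> row-stochastic matrix
    return softmax(logits / temperature, axis=-1)


logits = jnp.array([[2.0, 1.0, 0.0],
                    [0.0, 0.0, 0.0]])
warm = to_stochastic(logits, temperature=1.0)
cold = to_stochastic(logits, temperature=0.1)   # sharper: row 0 is nearly one-hot
hot = to_stochastic(logits, temperature=10.0)   # flatter: rows approach uniform
```

Equal logits (row 1) give a uniform row at any temperature, since dividing by the temperature keeps them equal.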

DifferentiableNBGLM

diffbio.operators.statistical.nb_glm.DifferentiableNBGLM

DifferentiableNBGLM(
    config: NBGLMConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable Negative Binomial GLM for differential expression.

This operator implements a negative binomial GLM in which:

  • log(mu) = X @ beta (design matrix @ coefficients)
  • P(count | mu, dispersion) = NB(count; mu, dispersion)

Gradients flow through both the coefficient (beta) and dispersion parameters, enabling end-to-end learning.

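The negative binomial distribution is parameterized as:

  • mean = mu
  • variance = mu + mu^2 / dispersion

Under this parameterization, dispersion behaves like the NB size parameter: larger values mean less overdispersion, and the variance approaches the Poisson variance mu as dispersion grows.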
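To make the parameterization concrete, here is a sketch of the NB log-probability written in terms of mean mu and dispersion r (the quantity dividing mu^2 in the variance formula above), using `jax.scipy.special.gammaln`. The function name `nb_log_pmf` is hypothetical and the code is not taken from diffbio. Summing the pmf over a long range of counts numerically recovers total mass 1, mean mu, and variance mu + mu^2 / r.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def nb_log_pmf(k, mu, r):
    """log NB(k; mu, r) with mean mu and variance mu + mu**2 / r."""
    log_r_mu = jnp.log(r + mu)
    return (gammaln(k + r) - gammaln(r) - gammaln(k + 1.0)
            + r * (jnp.log(r) - log_r_mu)
            + k * (jnp.log(mu) - log_r_mu))


mu, r = 5.0, 2.0
k = jnp.arange(0, 400, dtype=jnp.float32)  # long enough that the tail is negligible
p = jnp.exp(nb_log_pmf(k, mu, r))
total = p.sum()                            # ≈ 1
mean = (k * p).sum()                       # ≈ mu = 5.0
var = ((k - mean) ** 2 * p).sum()          # ≈ mu + mu**2 / r = 17.5
```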

Parameters:

  • config (NBGLMConfig, required): NBGLMConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

Example

```python
from flax import nnx
from diffbio.operators.statistical import DifferentiableNBGLM, NBGLMConfig

config = NBGLMConfig(n_features=2000, n_covariates=2)
glm = DifferentiableNBGLM(config, rngs=nnx.Rngs(42))
# counts: (n_features,), design_row: (n_covariates,), sf: scalar size factor
data = {"counts": counts, "design": design_row, "size_factor": sf}
result, state, meta = glm.apply(data, {}, None)
```


apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply NB GLM to count data.

This method computes the log likelihood and predicted mean for a given sample.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "counts": Gene counts, shape (n_features,)
      - "design": Design matrix row, shape (n_covariates,)
      - "size_factor": Library size factor (scalar)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used (deterministic operator).
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

A tuple[PyTree, PyTree, dict[str, Any] | None] of (transformed_data, state, metadata):

  • transformed_data contains:
      - "counts": Original counts
      - "log_likelihood": Log probability of the counts
      - "predicted_mean": Predicted expression
      - "dispersion": Dispersion parameters
  • state is passed through unchanged
  • metadata is passed through unchanged
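The per-sample log likelihood and its gradients can be sketched in plain JAX. This is not diffbio's code: the parameter names (`beta`, `log_dispersion`), the coefficient shape (n_covariates, n_features), the log parameterization that keeps dispersion positive, and the way the size factor enters (mu = size_factor * exp(design @ beta)) are all assumptions for illustration. `jax.value_and_grad` returns gradients for both the coefficients and the dispersions, which is the sense in which gradients flow through both.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def nb_log_pmf(k, mu, r):
    # log NB(k; mu, r): mean mu, variance mu + mu**2 / r
    log_r_mu = jnp.log(r + mu)
    return (gammaln(k + r) - gammaln(r) - gammaln(k + 1.0)
            + r * (jnp.log(r) - log_r_mu) + k * (jnp.log(mu) - log_r_mu))


def sample_log_likelihood(params, counts, design, size_factor):
    # Assumption: size factor scales the mean, mu = s * exp(design @ beta)
    mu = size_factor * jnp.exp(design @ params["beta"])   # (n_features,)
    r = jnp.exp(params["log_dispersion"])                  # (n_features,), always > 0
    return nb_log_pmf(counts, mu, r).sum()


n_features, n_covariates = 5, 2
params = {
    "beta": jnp.zeros((n_covariates, n_features)),
    "log_dispersion": jnp.zeros(n_features),
}
counts = jnp.array([0.0, 3.0, 10.0, 1.0, 7.0])  # hypothetical counts for one sample
design = jnp.array([1.0, 0.0])                  # intercept + one binary covariate
size_factor = 1.2

ll, grads = jax.value_and_grad(sample_log_likelihood)(params, counts, design, size_factor)
```

Because the second covariate is 0 for this sample, the second row of `grads["beta"]` is exactly zero; in the first row, genes observed above the predicted mean (count 10 vs mu = 1.2) get a positive gradient and zero-count genes a negative one.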

NBGLMConfig

diffbio.operators.statistical.nb_glm.NBGLMConfig dataclass

NBGLMConfig(
    n_features: int = 2000,
    n_covariates: int = 2,
    estimate_dispersion: bool = True,
)

Bases: OperatorConfig

Configuration for DifferentiableNBGLM.

Attributes:

  • n_features (int): Number of features (genes).
  • n_covariates (int): Number of covariates in the design matrix.
  • estimate_dispersion (bool): Whether to estimate dispersion parameters.

DifferentiableEMQuantifier

diffbio.operators.statistical.em_quantification.DifferentiableEMQuantifier

DifferentiableEMQuantifier(
    config: EMQuantifierConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable EM for transcript quantification.

This operator implements the EM algorithm for estimating transcript abundances from read-to-transcript compatibility data. The fixed number of iterations enables gradient flow through the entire quantification process.

Algorithm:

  1. Initialize abundances (learnable prior).
  2. E-step: probabilistically assign reads to transcripts:
     weights = softmax(compatibility * abundances / temperature)
  3. M-step: update abundances from the weighted counts (weights summed over reads):
     abundances = sum(weights) / effective_lengths
     abundances = abundances / sum(abundances)
  4. Repeat steps 2 and 3 for n_iterations.
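The four steps above can be sketched as an unrolled loop in plain JAX. The function `em_quantify` and its `prior_logits` argument are hypothetical; in particular, initializing abundances as softmax(prior_logits) is one assumed way to realize a "learnable prior". The E- and M-steps follow the formulas above literally, so transcripts with compatibility 0 still receive a small softmax weight; whether the real operator masks them is not stated here. Because the iteration count is a Python constant, the loop is unrolled at trace time and `jax.grad` differentiates through every iteration.

```python
import jax
import jax.numpy as jnp


def em_quantify(compat, eff_lens, prior_logits, n_iterations=10, temperature=1.0):
    """Unrolled soft EM.

    compat:       (n_reads, n_transcripts) read-transcript compatibility
    eff_lens:     (n_transcripts,) effective transcript lengths
    prior_logits: (n_transcripts,) unconstrained learnable prior (assumed)
    """
    # Step 1 (assumed parameterization): learnable prior -> initial abundances
    abundances = jax.nn.softmax(prior_logits)
    for _ in range(n_iterations):  # Step 4: fixed count, unrolled when traced
        # Step 2 (E-step): soft assignment of each read across transcripts
        weights = jax.nn.softmax(compat * abundances / temperature, axis=1)
        # Step 3 (M-step): sum weights over reads, length-normalize, renormalize
        abundances = weights.sum(axis=0) / eff_lens
        abundances = abundances / abundances.sum()
    return abundances


# Hypothetical toy data: 4 reads, 3 transcripts
compat = jnp.array([[1.0, 1.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [0.0, 1.0, 1.0],
                    [0.0, 0.0, 1.0]])
eff_lens = jnp.array([100.0, 150.0, 120.0])
prior_logits = jnp.zeros(3)
abund = em_quantify(compat, eff_lens, prior_logits)

# Gradient of one abundance w.r.t. the prior, taken through all iterations
g = jax.grad(lambda p: em_quantify(compat, eff_lens, p)[0])(prior_logits)
```

Since the prior enters only through a softmax, shifting all logits by a constant changes nothing, so the entries of `g` sum to zero.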

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection
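The standard constructions behind a logsumexp smooth maximum and a soft argmax are sketched below; that TemperatureOperator uses exactly these formulas (and these standalone signatures) is an assumption. soft_max(x, T) = T * logsumexp(x / T) is always at least max(x), at most max(x) + T * log(len(x)), and converges to max(x) as T approaches 0; soft_argmax returns the expected position under softmax(x / T).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_max(x, temperature):
    # Smooth maximum: T * logsumexp(x / T) >= max(x), tight as T -> 0
    return temperature * logsumexp(x / temperature)


def soft_argmax(x, temperature):
    # Expected index under softmax(x / T): a differentiable stand-in for argmax
    positions = jnp.arange(x.shape[0], dtype=x.dtype)
    return jnp.sum(jax.nn.softmax(x / temperature) * positions)


x = jnp.array([0.5, 3.0, 1.0, 2.9])
smooth_hi_t = soft_max(x, 1.0)      # noticeably above 3.0
smooth_lo_t = soft_max(x, 0.01)     # very close to 3.0
pos_hi_t = soft_argmax(x, 1.0)      # pulled toward index 3 by the near-tie
pos_lo_t = soft_argmax(x, 0.01)     # very close to index 1
```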

Parameters:

  • config (EMQuantifierConfig, required): EMQuantifierConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional operator name.

Example

```python
from flax import nnx
from diffbio.operators.statistical import DifferentiableEMQuantifier, EMQuantifierConfig

config = EMQuantifierConfig(n_transcripts=1000, n_iterations=10)
quantifier = DifferentiableEMQuantifier(config, rngs=nnx.Rngs(42))
# compat_matrix: (n_reads, n_transcripts), eff_lens: (n_transcripts,)
data = {"compatibility": compat_matrix, "effective_lengths": eff_lens}
result, state, meta = quantifier.apply(data, {}, None)
```


apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply EM quantification to read assignment data.

This method runs the EM algorithm to estimate transcript abundances from read-transcript compatibility data.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "compatibility": Read-transcript compatibility matrix, shape (n_reads, n_transcripts)
      - "effective_lengths": Effective transcript lengths, shape (n_transcripts,)
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used (deterministic operator).
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

A tuple[PyTree, PyTree, dict[str, Any] | None] of (transformed_data, state, metadata):

  • transformed_data contains:
      - "compatibility": Original compatibility matrix
      - "effective_lengths": Original effective lengths
      - "abundances": Estimated transcript abundances
      - "tpm": TPM (Transcripts Per Million) values
  • state is passed through unchanged
  • metadata is passed through unchanged
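For reference, the standard TPM definition divides read counts by effective length and rescales so the values sum to one million. The helper `tpm_from_counts` below is a hypothetical illustration of that formula, not the operator's code; if the operator's abundances are already length-normalized and sum to 1, as in step 3 of the algorithm, TPM reduces to 1e6 * abundances.

```python
import jax.numpy as jnp


def tpm_from_counts(read_counts, eff_lens):
    # Length-normalize, then rescale so the values sum to one million
    rate = read_counts / eff_lens
    return 1e6 * rate / rate.sum()


read_counts = jnp.array([120.0, 40.0, 300.0])   # hypothetical counts
eff_lens = jnp.array([1500.0, 800.0, 2500.0])
tpm = tpm_from_counts(read_counts, eff_lens)    # → [320000., 200000., 480000.]
```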

EMQuantifierConfig

diffbio.operators.statistical.em_quantification.EMQuantifierConfig dataclass

EMQuantifierConfig(
    n_transcripts: int = 1000,
    n_iterations: int = 10,
    temperature: float = 1.0,
)

Bases: OperatorConfig

Configuration for DifferentiableEMQuantifier.

Attributes:

  • n_transcripts (int): Number of transcripts to quantify.
  • n_iterations (int): Fixed number of EM iterations (for unrolling).
  • temperature (float): Temperature for softmax in the E-step.

Usage Examples

Hidden Markov Model

```python
import jax.numpy as jnp
from flax import nnx
from diffbio.operators.statistical import DifferentiableHMM, HMMConfig

config = HMMConfig(num_states=5, num_emissions=4)
hmm = DifferentiableHMM(config, rngs=nnx.Rngs(42))

observations = jnp.array([0, 2, 1, 3, 3, 0])  # integer-encoded symbols, shape (seq_len,)
data = {"observations": observations}
result, _, _ = hmm.apply(data, {}, None)
log_likelihood = result["log_likelihood"]
posteriors = result["state_posteriors"]  # P(state | observations) at each position
```

Negative Binomial GLM

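```python
from flax import nnx
from diffbio.operators.statistical import DifferentiableNBGLM, NBGLMConfig

config = NBGLMConfig(n_features=2000, n_covariates=10)
nbglm = DifferentiableNBGLM(config, rngs=nnx.Rngs(42))

# apply() operates on a single sample
data = {
    "counts": counts,              # (n_features,) counts for one sample
    "design": design_row,          # (n_covariates,) row of the design matrix
    "size_factor": size_factor,    # scalar library size factor
}
result, _, _ = nbglm.apply(data, {}, None)
predicted_mean = result["predicted_mean"]
log_likelihood = result["log_likelihood"]
```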

EM Quantifier

```python
from flax import nnx
from diffbio.operators.statistical import DifferentiableEMQuantifier, EMQuantifierConfig

config = EMQuantifierConfig(n_transcripts=50000, n_iterations=100)
em_quant = DifferentiableEMQuantifier(config, rngs=nnx.Rngs(42))

data = {
    "compatibility": compat_matrix,     # (n_reads, n_transcripts)
    "effective_lengths": eff_lengths,   # (n_transcripts,)
}
result, _, _ = em_quant.apply(data, {}, None)
abundances = result["abundances"]
tpm = result["tpm"]
```