Statistical Operators API¤
Differentiable statistical operators for probabilistic modeling including HMMs, GLMs, and EM algorithms.
DifferentiableHMM¤
diffbio.operators.statistical.hmm.DifferentiableHMM
¤
DifferentiableHMM(
    config: HMMConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
Bases: HMMOperator
Differentiable Hidden Markov Model.
This operator implements the HMM forward algorithm with differentiable operations, enabling gradient-based learning of transition and emission parameters.
The forward algorithm computes P(observations | model) using dynamic programming with logsumexp for numerical stability:
alpha[t, j] = sum_i(alpha[t-1, i] * A[i,j]) * B[j, o_t]
In log space: log_alpha[t, j] = logsumexp_i(log_alpha[t-1, i] + log_A[i,j]) + log_B[j, o_t]
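The log-space recursion above can be sketched in plain Python. This is a minimal illustration, not the operator's implementation: the function names, the explicit initial distribution `log_pi`, and the toy parameter values are assumptions made for the example, and the real operator presumably works on JAX arrays so that gradients can flow.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_log_likelihood(log_pi, log_A, log_B, observations):
    """Return log P(observations | model) via the log-space forward recursion.

    log_pi[i]   : log initial probability of state i (hypothetical name)
    log_A[i][j] : log transition probability from state i to state j
    log_B[j][o] : log probability of emitting symbol o from state j
    """
    n_states = len(log_pi)
    # Initialisation: log_alpha[0, j] = log_pi[j] + log_B[j, o_0]
    log_alpha = [log_pi[j] + log_B[j][observations[0]] for j in range(n_states)]
    for o_t in observations[1:]:
        # log_alpha[t, j] = logsumexp_i(log_alpha[t-1, i] + log_A[i, j]) + log_B[j, o_t]
        log_alpha = [
            logsumexp([log_alpha[i] + log_A[i][j] for i in range(n_states)])
            + log_B[j][o_t]
            for j in range(n_states)
        ]
    # Termination: marginalise over the final hidden state.
    return logsumexp(log_alpha)


# Toy 2-state, 2-symbol model; values chosen for illustration only.
log = math.log
log_pi = [log(0.6), log(0.4)]
log_A = [[log(0.7), log(0.3)], [log(0.4), log(0.6)]]
log_B = [[log(0.9), log(0.1)], [log(0.2), log(0.8)]]
log_lik = forward_log_likelihood(log_pi, log_A, log_B, [0, 1, 0])
```

Working in log space keeps every intermediate value in a safe floating-point range, which matters for long sequences where the raw alpha values would underflow.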
Inherits from HMMOperator to get:
- forward_pass() for likelihood computation
- forward_backward_posteriors() for posterior computation
- get_log_transition_matrix(), get_log_emission_matrix(), get_log_initial_distribution() for parameter access
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | HMMConfig | HMMConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply HMM to observation sequence.
This method computes the log-likelihood and state posteriors for a given observation sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing "observations": integer-encoded observations, shape (seq_len,). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used (deterministic operator). | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
HMMConfig¤
diffbio.operators.statistical.hmm.HMMConfig
dataclass
¤
HMMConfig(
    num_states: int = 3,
    num_emissions: int = 4,
    temperature: float = 1.0,
    learnable_transitions: bool = True,
    learnable_emissions: bool = True,
)
Bases: OperatorConfig
Configuration for DifferentiableHMM.
Attributes:
| Name | Type | Description |
|---|---|---|
| num_states | int | Number of hidden states. |
| num_emissions | int | Number of possible emissions (e.g., 4 for DNA). |
| temperature | float | Temperature for softmax operations. |
| learnable_transitions | bool | Whether transition probabilities are learnable. |
| learnable_emissions | bool | Whether emission probabilities are learnable. |
DifferentiableNBGLM¤
diffbio.operators.statistical.nb_glm.DifferentiableNBGLM
¤
DifferentiableNBGLM(
    config: NBGLMConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
Bases: OperatorModule
Differentiable Negative Binomial GLM for differential expression.
This operator implements a negative binomial GLM where:
- log(mu) = X @ beta (design matrix @ coefficients)
- P(count | mu, dispersion) = NB(count; mu, dispersion)
Gradients flow through both the coefficient (beta) and dispersion parameters, enabling end-to-end learning.
The negative binomial distribution is parameterized as:
- mean = mu
- variance = mu + mu^2 / dispersion
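To make the mean/dispersion parameterization concrete, here is a small plain-Python sketch of the log-likelihood for one gene in one sample. The function names and numeric values are hypothetical, the size factor is left out for brevity, and the operator's own implementation may differ in detail.

```python
import math


def nb_log_pmf(count, mu, dispersion):
    """Log NB(count; mu, dispersion) with variance = mu + mu**2 / dispersion."""
    r = dispersion  # plays the role of the NB "size" parameter
    return (
        math.lgamma(count + r) - math.lgamma(r) - math.lgamma(count + 1)
        + r * math.log(r / (r + mu))
        + count * math.log(mu / (r + mu))
    )


def predicted_mean(design_row, beta):
    """mu = exp(x @ beta) for one sample and one gene (log link)."""
    return math.exp(sum(x * b for x, b in zip(design_row, beta)))


# Hypothetical values: an intercept-only sample, so mu = exp(log(5)) = 5.
mu = predicted_mean([1.0, 0.0], [math.log(5.0), 0.3])
log_lik = nb_log_pmf(count=3, mu=mu, dispersion=2.0)
```

Under this parameterization a large dispersion value shrinks the mu^2 / dispersion term, so the distribution approaches a Poisson with mean mu.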
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | NBGLMConfig | NBGLMConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply NB GLM to count data.
This method computes the log likelihood and predicted mean for a given sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing: "counts": gene counts (n_features,); "design": design matrix row (n_covariates,); "size_factor": library size factor (scalar). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used (deterministic operator). | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
NBGLMConfig¤
diffbio.operators.statistical.nb_glm.NBGLMConfig
dataclass
¤
Bases: OperatorConfig
Configuration for DifferentiableNBGLM.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_features | int | Number of features (genes). |
| n_covariates | int | Number of covariates in design matrix. |
| estimate_dispersion | bool | Whether to estimate dispersion parameters. |
DifferentiableEMQuantifier¤
diffbio.operators.statistical.em_quantification.DifferentiableEMQuantifier
¤
DifferentiableEMQuantifier(
    config: EMQuantifierConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
Bases: TemperatureOperator
Differentiable EM for transcript quantification.
This operator implements the EM algorithm for estimating transcript abundances from read-to-transcript compatibility data. The fixed number of iterations enables gradient flow through the entire quantification process.
Algorithm:
1. Initialize abundances (learnable prior).
2. E-step: probabilistic assignment of reads to transcripts:
   weights = softmax(compatibility * abundances / temperature)
3. M-step: update abundances from weighted counts:
   abundances = sum(weights) / effective_lengths
   abundances = abundances / sum(abundances)
4. Repeat for n_iterations.
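The loop described above can be sketched in plain Python as follows. This is an illustration under stated assumptions rather than the operator's code: the function name `em_quantify` is hypothetical, the learnable prior is replaced by a uniform start, and the E-step logits follow the formula in the description literally.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def em_quantify(compatibility, effective_lengths, n_iterations=10, temperature=1.0):
    """Run a fixed number of soft EM iterations and return relative abundances."""
    n_transcripts = len(effective_lengths)
    # Assumption: uniform start instead of the operator's learnable prior.
    abundances = [1.0 / n_transcripts] * n_transcripts
    for _ in range(n_iterations):
        # E-step: soft assignment of each read across transcripts.
        weights = [
            softmax([c * a / temperature for c, a in zip(row, abundances)])
            for row in compatibility
        ]
        # M-step: length-normalised weighted counts, renormalised to sum to 1.
        raw = [
            sum(w[t] for w in weights) / effective_lengths[t]
            for t in range(n_transcripts)
        ]
        total = sum(raw)
        abundances = [r / total for r in raw]
    return abundances


# Toy input: 3 reads, 2 transcripts; the last read is compatible with both.
compatibility = [[1.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
effective_lengths = [1.0, 1.0]
abundances = em_quantify(compatibility, effective_lengths, n_iterations=20, temperature=0.5)
```

Because the iteration count is fixed rather than driven by a convergence test, the whole loop is a static computation graph, which is what allows an autodiff framework to differentiate through it.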
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
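As a rough picture of what a logsumexp-based smooth maximum and a soft position selection compute, consider the sketch below. The standalone signatures are assumptions for illustration only; the inherited methods may take different arguments and operate on JAX arrays.

```python
import math


def soft_max(values, temperature=1.0):
    """Smooth maximum: temperature * logsumexp(values / temperature)."""
    scaled = [v / temperature for v in values]
    m = max(scaled)
    return temperature * (m + math.log(sum(math.exp(s - m) for s in scaled)))


def soft_argmax(values, temperature=1.0):
    """Expected index under softmax(values / temperature)."""
    scaled = [v / temperature for v in values]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return sum(i * e / total for i, e in enumerate(exps))


scores = [0.1, 2.0, 0.5]
smooth = soft_max(scores, temperature=0.05)       # close to max(scores)
position = soft_argmax(scores, temperature=0.05)  # close to index 1
```

Lower temperatures bring both quantities closer to the hard max and argmax, while higher temperatures give smoother outputs with more evenly spread gradients.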
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | EMQuantifierConfig | EMQuantifierConfig with model parameters. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional operator name. | None |
apply
¤
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply EM quantification to read assignment data.
This method runs the EM algorithm to estimate transcript abundances from read-transcript compatibility data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Dictionary containing: "compatibility": read-transcript compatibility matrix (n_reads, n_transcripts); "effective_lengths": effective transcript lengths (n_transcripts,). | required |
| state | PyTree | Element state (passed through unchanged). | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through unchanged). | required |
| random_params | Any | Not used (deterministic operator). | None |
| stats | dict[str, Any] \| None | Not used. | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
EMQuantifierConfig¤
diffbio.operators.statistical.em_quantification.EMQuantifierConfig
dataclass
¤
Bases: OperatorConfig
Configuration for DifferentiableEMQuantifier.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_transcripts | int | Number of transcripts to quantify. |
| n_iterations | int | Fixed number of EM iterations (for unrolling). |
| temperature | float | Temperature for softmax in E-step. |
Usage Examples¤
Hidden Markov Model¤
from flax import nnx
from diffbio.operators.statistical import DifferentiableHMM, HMMConfig

config = HMMConfig(num_states=5, num_emissions=4)
hmm = DifferentiableHMM(config, rngs=nnx.Rngs(42))

# observations: integer-encoded sequence, shape (seq_len,)
data = {"observations": observations}
result, _, _ = hmm.apply(data, {}, None)
log_likelihood = result["log_likelihood"]
posteriors = result["posteriors"]
Negative Binomial GLM¤
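from flax import nnx
from diffbio.operators.statistical import DifferentiableNBGLM, NBGLMConfig

config = NBGLMConfig(n_features=2000, n_covariates=10)
nbglm = DifferentiableNBGLM(config, rngs=nnx.Rngs(42))

# apply() processes one sample; keys and shapes follow the apply() parameter table
data = {
    "counts": counts,            # gene counts, shape (n_features,)
    "design": design_row,        # design matrix row, shape (n_covariates,)
    "size_factor": size_factor,  # library size factor (scalar)
}
result, _, _ = nbglm.apply(data, {}, None)
coefficients = result["coefficients"]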
EM Quantifier¤
from flax import nnx
from diffbio.operators.statistical import DifferentiableEMQuantifier, EMQuantifierConfig

config = EMQuantifierConfig(n_transcripts=50000, n_iterations=100)
em_quant = DifferentiableEMQuantifier(config, rngs=nnx.Rngs(42))

# Keys and shapes follow the apply() parameter table
data = {
    "compatibility": compatibility,    # read-transcript matrix, shape (n_reads, n_transcripts)
    "effective_lengths": eff_lengths,  # effective transcript lengths, shape (n_transcripts,)
}
result, _, _ = em_quant.apply(data, {}, None)
tpm = result["tpm"]