RNA Structure Operators API
Differentiable operators for RNA secondary structure prediction.
DifferentiableRNAFold
diffbio.operators.rna_structure.rna_folding.DifferentiableRNAFold
```python
DifferentiableRNAFold(
    config: RNAFoldConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)
```
Bases: TemperatureOperator
Differentiable RNA secondary structure prediction.
This operator computes base pair probabilities for RNA sequences using a McCaskill-style partition function algorithm. The implementation uses temperature-controlled softmax for full differentiability.
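The McCaskill algorithm computes the partition function over all possible secondary structures P, Z = Σ_P exp(-E(P)/RT), where E(P) is the free energy of structure P, R is the gas constant, and T is the temperature. Base pair probabilities then follow as marginals: the probability that positions i and j are paired is the total Boltzmann weight of the structures containing the pair (i, j), divided by Z.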
For differentiability, we generalize the algorithm to operate on continuous probability distributions over nucleotides, following Matthies et al. (2024).
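To make the definition of Z concrete, here is a minimal pure-Python sketch for a toy, pair-additive energy model: each base pair (i, j) contributes an energy E[i][j], structures are non-crossing, and every hairpin encloses at least `min_hairpin` unpaired bases. It uses the classic Nussinov-style recursion rather than McCaskill's full nearest-neighbor model, folds RT into a single `temperature`, and is not DiffBio's implementation; `partition_function`, `enumerate_structures`, and `brute_force` are hypothetical names. The brute-force enumeration checks the recursion and reads off pair marginals directly from their definition.

```python
import itertools
import math
from functools import lru_cache


def partition_function(E, min_hairpin=3, temperature=1.0):
    """Z = sum over non-crossing structures P of exp(-E(P)/T), E(P) = sum of pair energies."""
    n = len(E)

    @lru_cache(maxsize=None)
    def z(i, j):
        if i >= j:  # empty range or a single base: only the unpaired chain
            return 1.0
        total = z(i + 1, j)  # case 1: base i is unpaired
        # case 2: base i pairs with k, leaving >= min_hairpin unpaired bases inside
        for k in range(i + min_hairpin + 1, j + 1):
            weight = math.exp(-E[i][k] / temperature)
            total += weight * z(i + 1, k - 1) * z(k + 1, j)
        return total

    return z(0, n - 1)


def enumerate_structures(n, min_hairpin=3):
    """Yield every non-crossing set of pairs (i, j), i < j, obeying the hairpin rule."""
    candidates = [(i, j) for i in range(n) for j in range(i + min_hairpin + 1, n)]
    for r in range(len(candidates) + 1):
        for pairs in itertools.combinations(candidates, r):
            bases = [b for pair in pairs for b in pair]
            if len(bases) != len(set(bases)):
                continue  # each base pairs at most once
            if any(a < c < b < d for (a, b) in pairs for (c, d) in pairs):
                continue  # pseudoknots (crossing pairs) are excluded
            yield pairs


def brute_force(E, min_hairpin=3, temperature=1.0):
    """Return Z and the marginal probability of every pair that occurs in some structure."""
    Z, pair_weight = 0.0, {}
    for pairs in enumerate_structures(len(E), min_hairpin):
        w = math.exp(-sum(E[i][j] for i, j in pairs) / temperature)
        Z += w
        for p in pairs:
            pair_weight[p] = pair_weight.get(p, 0.0) + w
    return Z, {p: w / Z for p, w in pair_weight.items()}


# Toy 8-nt example: pairs (0, 7) and (1, 6) are favorable, all other pairs are neutral.
n = 8
E = [[0.0] * n for _ in range(n)]
E[0][7] = E[1][6] = -3.0
Z_dp = partition_function(E)
Z_bf, probs = brute_force(E)
print(Z_dp, Z_bf, probs[(0, 7)])
```

The recursion runs in O(n³) time while the enumeration grows exponentially, which is why partition-function methods rely on dynamic programming.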
Input data structure:
- `sequence`: `Float[Array, "length 4"]` or `Float[Array, "batch length 4"]`. One-hot encoded RNA sequence (A=0, C=1, G=2, U=3).

Output data structure (adds):
- `bp_probs`: `Float[Array, "length length"]`. Base pair probabilities.
- `partition_function`: `Float[Array, ""]`. Log partition function.
Example
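The operator expects the one-hot layout described above. A minimal sketch of building that input from an RNA string, assuming NumPy; `one_hot_rna` is a hypothetical helper, not part of DiffBio, and the sequences are arbitrary illustrations. If the operator requires JAX arrays, the result can be converted with `jnp.asarray`.

```python
import numpy as np

# Index order documented above: A=0, C=1, G=2, U=3.
NUCLEOTIDE_INDEX = {"A": 0, "C": 1, "G": 2, "U": 3}


def one_hot_rna(seq: str) -> np.ndarray:
    """Encode an RNA string as a float32 array of shape (length, 4)."""
    idx = np.array([NUCLEOTIDE_INDEX[base] for base in seq.upper()])
    return np.eye(4, dtype=np.float32)[idx]


# A single sequence has shape (length, 4).
single = one_hot_rna("GGGAAAUCCC")
# Equal-length sequences can be stacked into the batched (batch, length, 4) layout.
batch = np.stack([one_hot_rna("GGGAAAUCCC"), one_hot_rna("GCGCAAAGCG")])
data = {"sequence": single}
```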
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `RNAFoldConfig` | Configuration with folding parameters. | required |
| `rngs` | `Rngs` | Random number generators. | required |
| `name` | `str \| None` | Optional operator name. | `None` |
apply
```python
apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
```
Apply RNA folding prediction to sequence data.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `PyTree` | Dictionary containing `"sequence"`: the one-hot encoded RNA sequence, shape `(length, 4)` or `(batch, length, 4)`. | required |
| `state` | `PyTree` | Element state (passed through unchanged). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| `random_params` | `Any` | Not used. | `None` |
| `stats` | `dict[str, Any] \| None` | Not used. | `None` |
Returns:
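
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of `(transformed_data, state, metadata)`. `transformed_data` contains the input `"sequence"` plus `"bp_probs"` (base pair probabilities) and `"partition_function"` (log partition function); `state` and `metadata` are returned unchanged. |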
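`apply` accepts both unbatched `(length, 4)` and batched `(batch, length, 4)` inputs. The NumPy sketch below shows one way such shape dispatch can be written and what the output keys and shapes look like; it is an illustration, not DiffBio's code. `fold_one` and `apply_sketch` are hypothetical, and `fold_one` is a placeholder that gives every pair i < j equal weight (so log Z is simply the log of the number of pairs).

```python
import numpy as np


def fold_one(seq):
    """Hypothetical per-sequence fold: every pair (i < j) gets weight 1.

    With all weights equal to 1, Z equals the number of pairs, so log Z = log(count).
    """
    n = seq.shape[0]
    upper = np.triu(np.ones((n, n), dtype=np.float32), k=1)
    count = upper.sum()
    probs = upper / count
    return probs + probs.T, np.float32(np.log(count))


def apply_sketch(data, state, metadata):
    """Dispatch on the rank of data["sequence"] and add the documented output keys."""
    seq = data["sequence"]
    if seq.ndim == 2:  # (length, 4)
        bp_probs, log_z = fold_one(seq)
    elif seq.ndim == 3:  # (batch, length, 4): fold each sequence independently
        results = [fold_one(s) for s in seq]
        bp_probs = np.stack([r[0] for r in results])
        log_z = np.stack([r[1] for r in results])
    else:
        raise ValueError(f"expected (length, 4) or (batch, length, 4), got {seq.shape}")
    out = {**data, "bp_probs": bp_probs, "partition_function": log_z}
    return out, state, metadata  # state and metadata pass through unchanged


x = np.eye(4, dtype=np.float32)[np.zeros(6, dtype=int)]  # six A's, shape (6, 4)
out, state, meta = apply_sketch({"sequence": x}, {"step": 0}, None)
batched, _, _ = apply_sketch({"sequence": np.stack([x, x, x])}, {}, None)
```

A JAX implementation would more naturally use `jax.vmap` over the batch axis instead of a Python loop.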
RNAFoldConfig
diffbio.operators.rna_structure.rna_folding.RNAFoldConfig
dataclass
```python
RNAFoldConfig(
    alphabet_size: int = 4,
    bp_energy_au: float = BP_ENERGY_AU,
    bp_energy_gc: float = BP_ENERGY_GC,
    bp_energy_gu: float = BP_ENERGY_GU,
    learnable_temperature: bool = False,
    cacheable: bool = True,
    temperature: float = 1.0,
    min_hairpin_loop: int = DEFAULT_MIN_HAIRPIN,
)
```
Bases: _RNAFoldRuntimeConfig, _RNAFoldEnergyConfig, OperatorConfig
Configuration for DifferentiableRNAFold.
Factory Function
create_rna_fold_predictor
diffbio.operators.rna_structure.rna_folding.create_rna_fold_predictor
```python
create_rna_fold_predictor(
    temperature: float = 1.0,
    min_hairpin_loop: int = DEFAULT_MIN_HAIRPIN,
    bp_energy_au: float = BP_ENERGY_AU,
    bp_energy_gc: float = BP_ENERGY_GC,
    bp_energy_gu: float = BP_ENERGY_GU,
    *,
    rngs: Rngs | None = None,
) -> DifferentiableRNAFold
```
Create an RNA fold predictor with given parameters.
Factory function for convenient predictor creation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `temperature` | `float` | Softmax temperature for the Boltzmann distribution. | `1.0` |
| `min_hairpin_loop` | `int` | Minimum hairpin loop size. | `DEFAULT_MIN_HAIRPIN` |
| `bp_energy_au` | `float` | Energy for an A-U base pair. | `BP_ENERGY_AU` |
| `bp_energy_gc` | `float` | Energy for a G-C base pair. | `BP_ENERGY_GC` |
| `bp_energy_gu` | `float` | Energy for a G-U wobble pair. | `BP_ENERGY_GU` |
| `rngs` | `Rngs \| None` | Random number generators. | `None` |

Returns:

| Type | Description |
|---|---|
| `DifferentiableRNAFold` | Configured `DifferentiableRNAFold` instance. |
Helper Functions
compute_pair_energy_matrix
diffbio.operators.rna_structure.rna_folding.compute_pair_energy_matrix
```python
compute_pair_energy_matrix(
    sequence: Float[Array, "length 4"],
    bp_energy_au: float = BP_ENERGY_AU,
    bp_energy_gc: float = BP_ENERGY_GC,
    bp_energy_gu: float = BP_ENERGY_GU,
) -> Float[Array, "length length"]
```
Compute base pair energy matrix for RNA sequence.
Uses Watson-Crick and wobble base pairing rules:
- A-U: 2 hydrogen bonds (medium strength)
- G-C: 3 hydrogen bonds (strongest)
- G-U: wobble pair (weakest)
For soft/probabilistic sequences, the energy is weighted by the probability of each nucleotide at each position.
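One natural reading of this probability weighting is an expected energy: if p_i is the nucleotide distribution at position i and M is a 4×4 table of pair energies, then E[i, j] = Σ_a Σ_b p_i(a) M[a, b] p_j(b), i.e. E = S M Sᵀ for the `(length, 4)` sequence matrix S. The NumPy sketch below assumes this bilinear form, a symmetric M (C-G scored like G-C, and so on), zero energy for non-canonical pairs, and the default energies from the Base Pair Energies table; `pair_energy_table` and `expected_pair_energy` are hypothetical names, and DiffBio's exact weighting may differ.

```python
import numpy as np

A, C, G, U = 0, 1, 2, 3  # documented index order


def pair_energy_table(au=-2.0, gc=-3.0, gu=-1.0):
    """Symmetric 4x4 table of pair energies; non-canonical pairs get 0."""
    M = np.zeros((4, 4), dtype=np.float32)
    for a, b, e in [(A, U, au), (G, C, gc), (G, U, gu)]:
        M[a, b] = M[b, a] = e
    return M


def expected_pair_energy(seq, M):
    """seq: (length, 4) nucleotide probabilities -> (length, length) expected energies."""
    # E[i, j] = sum_{a, b} seq[i, a] * M[a, b] * seq[j, b]
    return seq @ M @ seq.T


M = pair_energy_table()
hard = np.eye(4, dtype=np.float32)[[G, A, C, U]]  # one-hot "GACU"
E_hard = expected_pair_energy(hard, M)

soft = hard.copy()
soft[0] = [0.5, 0.0, 0.5, 0.0]  # position 0 is now 50% A / 50% G
E_soft = expected_pair_energy(soft, M)
```

For one-hot input this reduces to a table lookup (E_hard[0, 2] is the G-C energy), while a soft position interpolates: E_soft[0, 2] is half of the G-C energy because the A-C alternative contributes 0. Because the expression is bilinear in the sequence probabilities, it is differentiable with respect to them.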
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sequence` | `Float[Array, 'length 4']` | One-hot encoded RNA sequence (A=0, C=1, G=2, U=3). | required |
| `bp_energy_au` | `float` | Energy for an A-U pair. | `BP_ENERGY_AU` |
| `bp_energy_gc` | `float` | Energy for a G-C pair. | `BP_ENERGY_GC` |
| `bp_energy_gu` | `float` | Energy for a G-U wobble pair. | `BP_ENERGY_GU` |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'length length']` | Energy matrix where `[i, j]` is the pairing energy for positions i and j. More negative values mean more favorable pairing. |
compute_base_pair_probabilities
diffbio.operators.rna_structure.rna_folding.compute_base_pair_probabilities
```python
compute_base_pair_probabilities(
    energy_matrix: Float[Array, "length length"],
    min_hairpin: int = DEFAULT_MIN_HAIRPIN,
    temperature: Array | float = 1.0,
) -> tuple[Float[Array, "length length"], Float[Array, ""]]
```
Compute base pair probability matrix.
Uses a simplified approach: probabilities are derived from the Boltzmann-weighted base pair energies, normalized over all valid positions.
A full McCaskill implementation would compute inside-outside probabilities; for differentiability and simplicity, this function instead uses P[i,j] ∝ exp(-E[i,j]/T) * validity_mask[i,j].
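A minimal NumPy sketch of the formula P[i,j] ∝ exp(-E[i,j]/T) * validity_mask[i,j]. It assumes a single global normalization over the upper triangle, a hairpin rule of j - i > min_hairpin, and a symmetrized output; DiffBio's exact masking and normalization may differ, and `base_pair_probabilities` is a hypothetical name. The normalizer is computed with a max-shifted log-sum-exp so that low temperatures do not overflow.

```python
import numpy as np


def base_pair_probabilities(E, min_hairpin=3, temperature=1.0):
    """Globally normalized P[i, j] ∝ exp(-E[i, j] / T) over valid pairs.

    Valid pairs are i < j with more than `min_hairpin` positions between their indices.
    Returns a symmetric probability matrix and log Z.
    """
    n = E.shape[0]
    i, j = np.indices((n, n))
    valid = (j - i) > min_hairpin  # upper triangle only, hairpin rule enforced
    if not valid.any():
        raise ValueError("sequence too short for any valid pair")
    logits = np.where(valid, -E / temperature, -np.inf)
    # Stable log-sum-exp over the valid entries only.
    m = logits[valid].max()
    log_z = m + np.log(np.exp(logits[valid] - m).sum())
    upper = np.exp(logits - log_z)  # exp(-inf) = 0 outside the mask
    return upper + upper.T, log_z


rng = np.random.default_rng(0)
E = rng.normal(size=(12, 12))
E = (E + E.T) / 2  # symmetric toy energies
P_warm, log_z = base_pair_probabilities(E, temperature=1.0)
P_cold, _ = base_pair_probabilities(E, temperature=0.25)
```

Lowering the temperature concentrates probability mass on the most favorable pairs, so `P_cold` has a larger maximum entry than `P_warm`. Note that this normalization treats pairs as mutually exclusive alternatives rather than as parts of non-crossing structures, which is exactly the simplification relative to McCaskill described above.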
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `energy_matrix` | `Float[Array, 'length length']` | Base pair energy matrix. | required |
| `min_hairpin` | `int` | Minimum hairpin loop size. | `DEFAULT_MIN_HAIRPIN` |
| `temperature` | `Array \| float` | Temperature for the Boltzmann distribution. | `1.0` |

Returns:

| Type | Description |
|---|---|
| `tuple[Float[Array, 'length length'], Float[Array, '']]` | Tuple of `(bp_probs, log_Z)`, where `bp_probs` is the `(length, length)` base pair probability matrix and `log_Z` is the scalar log partition function. |
Usage Examples
Basic Usage
```python
from diffbio.operators.rna_structure import create_rna_fold_predictor
import jax
import jax.numpy as jnp

# Create predictor
predictor = create_rna_fold_predictor(temperature=1.0)

# Prepare one-hot encoded RNA sequence
sequence = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (50,), 0, 4),
    num_classes=4,
)

# Apply
result, _, _ = predictor.apply({"sequence": sequence}, {}, None)
bp_probs = result["bp_probs"]  # (50, 50)
```
Full Configuration
```python
from diffbio.operators.rna_structure import (
    DifferentiableRNAFold,
    RNAFoldConfig,
)
from flax import nnx

config = RNAFoldConfig(
    temperature=0.5,     # Sharper predictions
    min_hairpin_loop=3,  # Standard hairpin constraint
    bp_energy_au=-2.0,   # A-U pair energy
    bp_energy_gc=-3.0,   # G-C pair energy
    bp_energy_gu=-1.0,   # G-U wobble energy
)
predictor = DifferentiableRNAFold(config, rngs=nnx.Rngs(42))
```
Batched Processing
```python
import jax

# Batch of RNA sequences (reuses `predictor` from the examples above)
batch_size = 8
seq_len = 50
sequences = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (batch_size, seq_len), 0, 4),
    num_classes=4,
)

result, _, _ = predictor.apply({"sequence": sequences}, {}, None)
bp_probs = result["bp_probs"]  # (8, 50, 50)
```
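Gradient Computation
```python
import jax
from flax import nnx
from diffbio.operators.rna_structure import create_rna_fold_predictor

predictor = create_rna_fold_predictor()

# One-hot RNA sequence; it must be longer than 40 nt to index bp_probs[10, 40]
sequence = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (50,), 0, 4),
    num_classes=4,
)

def loss_fn(model, sequence):
    result, _, _ = model.apply({"sequence": sequence}, {}, None)
    # Example: maximize the probability that positions 10 and 40 pair
    return -result["bp_probs"][10, 40]

# Compute gradients w.r.t. model parameters
_, grads = nnx.value_and_grad(loss_fn)(predictor, sequence)
```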
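Gradients flow through the operator because the temperature-controlled softmax it uses is smooth in the energies. For a softmax over pair energies, P_k = exp(-E_k/T) / Σ_m exp(-E_m/T), the Jacobian has the closed form ∂P_k/∂E_l = -(1/T) · P_k (δ_kl - P_l). The NumPy sketch below checks that formula against central finite differences. It illustrates the math only and does not call DiffBio; `softmax_probs`, `analytic_jacobian`, and `numeric_jacobian` are hypothetical names, and the four energies are arbitrary.

```python
import numpy as np


def softmax_probs(energies, temperature=1.0):
    """Boltzmann softmax over a vector of pair energies (max-shifted for stability)."""
    logits = -energies / temperature
    w = np.exp(logits - logits.max())
    return w / w.sum()


def analytic_jacobian(energies, temperature=1.0):
    """J[k, l] = dP_k / dE_l = -(1/T) * P_k * (delta_kl - P_l)."""
    p = softmax_probs(energies, temperature)
    return -(np.diag(p) - np.outer(p, p)) / temperature


def numeric_jacobian(energies, temperature=1.0, eps=1e-6):
    """Central finite differences, one energy at a time."""
    jac = np.zeros((energies.size, energies.size))
    for col in range(energies.size):
        step = np.zeros_like(energies)
        step[col] = eps
        jac[:, col] = (
            softmax_probs(energies + step, temperature)
            - softmax_probs(energies - step, temperature)
        ) / (2 * eps)
    return jac


E = np.array([-3.0, -2.0, -1.0, 0.5])  # energies of four candidate pairs
J_analytic = analytic_jacobian(E, temperature=0.5)
J_numeric = numeric_jacobian(E, temperature=0.5)
```

Each column of the Jacobian sums to zero because the probabilities always sum to one. The explicit 1/T factor means that at low temperature the gradient is scaled up where probabilities are not yet saturated and vanishes where they are, the usual trade-off when annealing a softmax.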
Input Specifications

| Key | Shape | Type | Description |
|---|---|---|---|
| `sequence` | `(length, 4)` or `(batch, length, 4)` | float32 | One-hot encoded RNA (A=0, C=1, G=2, U=3) |
Output Specifications

| Key | Shape | Type | Description |
|---|---|---|---|
| `sequence` | same as input | float32 | Original input sequence |
| `bp_probs` | `(length, length)` or `(batch, length, length)` | float32 | Base pair probability matrix |
| `partition_function` | `()` or `(batch,)` | float32 | Log partition function |
Base Pair Energies
| Pair | Default Energy | Description |
|---|---|---|
| A-U | -2.0 | Watson-Crick (2 H-bonds) |
| G-C | -3.0 | Watson-Crick (3 H-bonds) |
| G-U | -1.0 | Wobble pair |
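Under the Boltzmann weighting exp(-E/T), an energy gap of 1.0 between two pair types becomes a factor of e^(1/T) in relative weight. A quick check with the default energies from the table above (plain Python, no DiffBio required):

```python
import math

# Default pair energies from the table above.
energies = {"A-U": -2.0, "G-C": -3.0, "G-U": -1.0}

ratios = {}
for temperature in (1.0, 0.5):
    weights = {pair: math.exp(-e / temperature) for pair, e in energies.items()}
    ratios[temperature] = weights["G-C"] / weights["A-U"]
    summary = ", ".join(f"{pair}={w:.2f}" for pair, w in weights.items())
    print(f"T={temperature}: {summary}, G-C/A-U={ratios[temperature]:.2f}")
# T=1.0: A-U=7.39, G-C=20.09, G-U=2.72, G-C/A-U=2.72
# T=0.5: A-U=54.60, G-C=403.43, G-U=7.39, G-C/A-U=7.39
```

Halving the temperature squares the G-C:A-U weight ratio (from e ≈ 2.72 to e² ≈ 7.39), which is what "sharper predictions" means for `temperature=0.5`.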