RNA Structure Operators API¤

Differentiable operators for RNA secondary structure prediction.

DifferentiableRNAFold¤

diffbio.operators.rna_structure.rna_folding.DifferentiableRNAFold ¤

DifferentiableRNAFold(
    config: RNAFoldConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable RNA secondary structure prediction.

This operator computes base pair probabilities for RNA sequences using a McCaskill-style partition function algorithm. The implementation uses temperature-controlled softmax for full differentiability.

The McCaskill algorithm computes Z = Σ_P exp(-E(P)/RT), the partition function over all possible secondary structures. From this, base pair probabilities are derived as the marginal probability that positions i and j are paired in the ensemble.
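To make the sum concrete, here is a minimal sketch of a partition function over non-crossing structures with purely additive pair energies. It is not the operator's implementation: the recursion, the `PAIR_ENERGY` table, the `partition_function` name, and the use of a single `temperature` in place of RT are all assumptions made for illustration.

```python
import math
from functools import lru_cache

# Illustrative pair energies (hypothetical table using the defaults listed in this reference).
PAIR_ENERGY = {
    ("A", "U"): -2.0, ("U", "A"): -2.0,
    ("G", "C"): -3.0, ("C", "G"): -3.0,
    ("G", "U"): -1.0, ("U", "G"): -1.0,
}


def partition_function(seq: str, min_hairpin: int = 3, temperature: float = 1.0) -> float:
    """Sum of Boltzmann weights over all non-crossing structures of seq."""

    @lru_cache(maxsize=None)
    def z(i: int, j: int) -> float:
        # Empty or single-base intervals admit only the unpaired structure.
        if j <= i:
            return 1.0
        # Case 1: base i is unpaired.
        total = z(i + 1, j)
        # Case 2: base i pairs with k, enclosing at least min_hairpin bases.
        for k in range(i + min_hairpin + 1, j + 1):
            energy = PAIR_ENERGY.get((seq[i], seq[k]))
            if energy is not None:
                weight = math.exp(-energy / temperature)
                total += weight * z(i + 1, k - 1) * z(k + 1, j)
        return total

    return z(0, len(seq) - 1)


z_total = partition_function("GAAAC")
```

For `"GAAAC"` with a minimum hairpin loop of 3, the only admissible pair is G(0)-C(4), so Z = 1 + exp(3), and the marginal probability of that pair is exp(3) / (1 + exp(3)).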

For differentiability, we generalize the algorithm to operate on continuous probability distributions over nucleotides, following Matthies et al. (2024).

Input data structure:
  • sequence: Float[Array, "length 4"] or Float[Array, "batch length 4"]. One-hot encoded RNA sequence (A=0, C=1, G=2, U=3).

Output data structure (adds):
  • bp_probs: Float[Array, "length length"]. Base pair probabilities.
  • partition_function: Float[Array, ""]. Log partition function.

Example
import jax
import jax.numpy as jnp
from flax import nnx

config = RNAFoldConfig(temperature=1.0)
predictor = DifferentiableRNAFold(config, rngs=nnx.Rngs(42))
seq_indices = jnp.array([2, 0, 0, 0, 1])  # G A A A C
sequence = jax.nn.one_hot(seq_indices, num_classes=4)
result, _, _ = predictor.apply({"sequence": sequence}, {}, None)
bp_probs = result["bp_probs"]  # (length, length)

Parameters:

  • config (RNAFoldConfig, required): Configuration with folding parameters.
  • rngs (Rngs, required): Random number generators.
  • name (str | None, default None): Optional operator name.

apply ¤

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply RNA folding prediction to sequence data.

Parameters:

  • data (PyTree, required): Dictionary containing "sequence", the one-hot encoded RNA sequence with shape (length, 4) or (batch, length, 4).
  • state (PyTree, required): Element state (passed through unchanged).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through unchanged).
  • random_params (Any, default None): Not used.
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

  • tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata), where:
      - transformed_data contains all original keys from data, plus "bp_probs" (base pair probability matrix) and "partition_function" (log partition function)
      - state is passed through unchanged
      - metadata is passed through unchanged

RNAFoldConfig¤

diffbio.operators.rna_structure.rna_folding.RNAFoldConfig dataclass ¤

RNAFoldConfig(
    alphabet_size: int = 4,
    bp_energy_au: float = BP_ENERGY_AU,
    bp_energy_gc: float = BP_ENERGY_GC,
    bp_energy_gu: float = BP_ENERGY_GU,
    learnable_temperature: bool = False,
    cacheable: bool = True,
    temperature: float = 1.0,
    min_hairpin_loop: int = DEFAULT_MIN_HAIRPIN,
)

Bases: _RNAFoldRuntimeConfig, _RNAFoldEnergyConfig, OperatorConfig

Configuration for DifferentiableRNAFold.

alphabet_size class-attribute instance-attribute ¤

alphabet_size: int = 4

bp_energy_au class-attribute instance-attribute ¤

bp_energy_au: float = BP_ENERGY_AU

bp_energy_gc class-attribute instance-attribute ¤

bp_energy_gc: float = BP_ENERGY_GC

bp_energy_gu class-attribute instance-attribute ¤

bp_energy_gu: float = BP_ENERGY_GU

learnable_temperature class-attribute instance-attribute ¤

learnable_temperature: bool = False

cacheable class-attribute instance-attribute ¤

cacheable: bool = True

temperature class-attribute instance-attribute ¤

temperature: float = 1.0

min_hairpin_loop class-attribute instance-attribute ¤

min_hairpin_loop: int = DEFAULT_MIN_HAIRPIN

Factory Function¤

create_rna_fold_predictor¤

diffbio.operators.rna_structure.rna_folding.create_rna_fold_predictor ¤

create_rna_fold_predictor(
    temperature: float = 1.0,
    min_hairpin_loop: int = DEFAULT_MIN_HAIRPIN,
    bp_energy_au: float = BP_ENERGY_AU,
    bp_energy_gc: float = BP_ENERGY_GC,
    bp_energy_gu: float = BP_ENERGY_GU,
    *,
    rngs: Rngs | None = None,
) -> DifferentiableRNAFold

Create an RNA fold predictor with given parameters.

Factory function for convenient predictor creation.

Parameters:

  • temperature (float, default 1.0): Softmax temperature for Boltzmann distribution.
  • min_hairpin_loop (int, default DEFAULT_MIN_HAIRPIN): Minimum hairpin loop size.
  • bp_energy_au (float, default BP_ENERGY_AU): Energy for A-U base pair.
  • bp_energy_gc (float, default BP_ENERGY_GC): Energy for G-C base pair.
  • bp_energy_gu (float, default BP_ENERGY_GU): Energy for G-U wobble pair.
  • rngs (Rngs | None, default None): Random number generators.

Returns:

  • DifferentiableRNAFold: Configured DifferentiableRNAFold instance.

Example
import jax
seq = jax.nn.one_hot(jax.numpy.array([2, 0, 0, 0, 1]), num_classes=4)  # G A A A C
predictor = create_rna_fold_predictor(temperature=0.5)
result, _, _ = predictor.apply({"sequence": seq}, {}, None)

Helper Functions¤

compute_pair_energy_matrix¤

diffbio.operators.rna_structure.rna_folding.compute_pair_energy_matrix ¤

compute_pair_energy_matrix(
    sequence: Float[Array, "length 4"],
    bp_energy_au: float = BP_ENERGY_AU,
    bp_energy_gc: float = BP_ENERGY_GC,
    bp_energy_gu: float = BP_ENERGY_GU,
) -> Float[Array, "length length"]

Compute base pair energy matrix for RNA sequence.

Uses Watson-Crick and wobble base pairing rules:
  • A-U: 2 hydrogen bonds (medium strength)
  • G-C: 3 hydrogen bonds (strongest)
  • G-U: wobble pair (weakest)

For soft/probabilistic sequences, the energy is weighted by the probability of each nucleotide at each position.
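One way to read this weighting is as a bilinear form: the expected energy at (i, j) sums the pair energy of every nucleotide combination, weighted by the probability of each nucleotide at each position. The sketch below assumes that reading. The names `pair_energy_table` and `expected_pair_energy` are hypothetical, and using zero for non-pairing combinations is an assumption rather than documented behavior.

```python
import jax.numpy as jnp

A, C, G, U = 0, 1, 2, 3  # index convention from this reference


def pair_energy_table(e_au=-2.0, e_gc=-3.0, e_gu=-1.0):
    """Symmetric 4x4 table; zero is an assumed placeholder for non-pairing bases."""
    table = jnp.zeros((4, 4))
    for a, b, e in ((A, U, e_au), (G, C, e_gc), (G, U, e_gu)):
        table = table.at[a, b].set(e).at[b, a].set(e)
    return table


def expected_pair_energy(sequence, table):
    """E[i, j] = sum over a, b of p_i(a) * table[a, b] * p_j(b)."""
    return sequence @ table @ sequence.T


table = pair_energy_table()
# Position 0 is half A and half G; position 1 is certainly U.
soft = jnp.array([[0.5, 0.0, 0.5, 0.0],
                  [0.0, 0.0, 0.0, 1.0]])
energy = expected_pair_energy(soft, table)
# energy[0, 1] = 0.5 * (-2.0) + 0.5 * (-1.0) = -1.5
```

With a strictly one-hot input the same expression reduces to a table lookup, so hard and soft sequences share one code path.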

Parameters:

  • sequence (Float[Array, 'length 4'], required): One-hot encoded RNA sequence (A=0, C=1, G=2, U=3).
  • bp_energy_au (float, default BP_ENERGY_AU): Energy for A-U pair.
  • bp_energy_gc (float, default BP_ENERGY_GC): Energy for G-C pair.
  • bp_energy_gu (float, default BP_ENERGY_GU): Energy for G-U wobble pair.

Returns:

  • Float[Array, 'length length']: Energy matrix where [i,j] is the pairing energy for positions i,j. More negative = more favorable pairing.

compute_base_pair_probabilities¤

diffbio.operators.rna_structure.rna_folding.compute_base_pair_probabilities ¤

compute_base_pair_probabilities(
    energy_matrix: Float[Array, "length length"],
    min_hairpin: int = DEFAULT_MIN_HAIRPIN,
    temperature: Array | float = 1.0,
) -> tuple[Float[Array, "length length"], Float[Array, ""]]

Compute base pair probability matrix.

Uses a simplified approach where probabilities are derived from the Boltzmann-weighted base pair energies normalized over all valid positions.

For full McCaskill, one would compute inside-outside probabilities, but for differentiability and simplicity, we use: P[i,j] ∝ exp(-E[i,j]/T) * validity_mask[i,j]
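The proportionality above can be sketched as a masked softmax over the whole matrix. This is an illustration under stated assumptions, not the library's code: the function name `boltzmann_pair_probs` is hypothetical, the validity rule `j - i > min_hairpin` is an assumed reading of the hairpin constraint, and only the upper triangle is filled.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def boltzmann_pair_probs(energy, min_hairpin=3, temperature=1.0):
    """Return (probs, log_z) with probs proportional to exp(-E / T) on valid pairs."""
    n = energy.shape[0]
    idx = jnp.arange(n)
    # Assumed rule: i < j with at least min_hairpin unpaired bases in between.
    valid = (idx[None, :] - idx[:, None]) > min_hairpin
    # A large negative score removes invalid pairs from the normalization.
    scores = jnp.where(valid, -energy / temperature, -1e9)
    log_z = logsumexp(scores)
    probs = jnp.where(valid, jnp.exp(scores - log_z), 0.0)
    return probs, log_z


energy = jnp.zeros((6, 6))
probs, log_z = boltzmann_pair_probs(energy)
# Valid pairs are (0, 4), (0, 5), (1, 5); equal energies give each probability 1/3.
```

Note that this normalization makes the whole matrix sum to 1, which is different from inside-outside marginals, where each entry is a separate probability over the structure ensemble.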

Parameters:

  • energy_matrix (Float[Array, 'length length'], required): Base pair energy matrix.
  • min_hairpin (int, default DEFAULT_MIN_HAIRPIN): Minimum hairpin loop size.
  • temperature (Array | float, default 1.0): Temperature for Boltzmann distribution.

Returns:

  • tuple[Float[Array, 'length length'], Float[Array, '']]: Tuple of (bp_probs, log_Z) where:
      - bp_probs[i,j] is the probability that i and j are paired
      - log_Z is the log partition function (logsumexp of valid pairs)

Usage Examples¤

Basic Usage¤

from diffbio.operators.rna_structure import create_rna_fold_predictor
import jax
import jax.numpy as jnp

# Create predictor
predictor = create_rna_fold_predictor(temperature=1.0)

# Prepare one-hot encoded RNA sequence
sequence = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (50,), 0, 4),
    num_classes=4,
)

# Apply
result, _, _ = predictor.apply({"sequence": sequence}, {}, None)
bp_probs = result["bp_probs"]  # (50, 50)

Full Configuration¤

from diffbio.operators.rna_structure import (
    DifferentiableRNAFold,
    RNAFoldConfig,
)
from flax import nnx

config = RNAFoldConfig(
    temperature=0.5,       # Sharper predictions
    min_hairpin_loop=3,    # Standard hairpin constraint
    bp_energy_au=-2.0,     # A-U pair energy
    bp_energy_gc=-3.0,     # G-C pair energy
    bp_energy_gu=-1.0,     # G-U wobble energy
)

predictor = DifferentiableRNAFold(config, rngs=nnx.Rngs(42))
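The "sharper predictions" effect of a lower temperature can be seen on the three default pair energies alone. The snippet below is a standalone sketch that does not call the operator; `boltzmann_weights` is a hypothetical helper that applies a softmax to `-energy / temperature`.

```python
import jax
import jax.numpy as jnp

energies = jnp.array([-2.0, -3.0, -1.0])  # A-U, G-C, G-U defaults


def boltzmann_weights(energies, temperature):
    """Normalized Boltzmann weights exp(-E / T) / sum(exp(-E / T))."""
    return jax.nn.softmax(-energies / temperature)


w_warm = boltzmann_weights(energies, 1.0)  # G-C weight is about 0.67
w_cold = boltzmann_weights(energies, 0.5)  # G-C weight is about 0.87
```

Halving the temperature doubles every energy gap inside the exponent, so the most favorable pair takes a larger share of the probability mass.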

Batched Processing¤

import jax
from diffbio.operators.rna_structure import create_rna_fold_predictor

predictor = create_rna_fold_predictor()

# Batch of RNA sequences
batch_size = 8
seq_len = 50
sequences = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (batch_size, seq_len), 0, 4),
    num_classes=4,
)

result, _, _ = predictor.apply({"sequence": sequences}, {}, None)
bp_probs = result["bp_probs"]  # (8, 50, 50)
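The batched shapes above match what `jax.vmap` produces when a per-sequence function is mapped over a leading batch axis. Whether the operator uses `vmap` internally is not stated in this reference, so treat the following as a sketch of the shape contract only; `PAIR_TABLE` and `pair_energy` are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in energies over (A, C, G, U); zero marks non-pairing bases.
PAIR_TABLE = jnp.array([
    [0.0, 0.0, 0.0, -2.0],
    [0.0, 0.0, -3.0, 0.0],
    [0.0, -3.0, 0.0, -1.0],
    [-2.0, 0.0, -1.0, 0.0],
])


def pair_energy(sequence):
    """Per-sequence function: (length, 4) -> (length, length)."""
    return sequence @ PAIR_TABLE @ sequence.T


sequences = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (8, 50), 0, 4),
    num_classes=4,
)
batched_energy = jax.vmap(pair_energy)(sequences)  # (8, 50, 50)
single_energy = pair_energy(sequences[0])          # (50, 50)
```

Each slice of the batched result equals the result for that sequence processed on its own.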

Gradient Computation¤

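import jax
from flax import nnx
from diffbio.operators.rna_structure import create_rna_fold_predictor

predictor = create_rna_fold_predictor()
sequence = jax.nn.one_hot(
    jax.random.randint(jax.random.PRNGKey(0), (50,), 0, 4),
    num_classes=4,
)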

def loss_fn(model, sequence):
    result, _, _ = model.apply({"sequence": sequence}, {}, None)
    # Example: maximize specific base pair probability
    return -result["bp_probs"][10, 40]

# Compute gradients w.r.t. model parameters
_, grads = nnx.value_and_grad(loss_fn)(predictor, sequence)
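Gradients can also be taken with respect to the sequence itself when it is relaxed to per-position logits, which is the typical setup for sequence design. The sketch below is self-contained and does not call the operator: `TABLE` and `pair_prob` are hypothetical, and the probability model is the simplified masked softmax with an assumed `j - i > min_hairpin` validity rule.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical pair energies over (A, C, G, U); zero marks non-pairing bases.
TABLE = jnp.array([
    [0.0, 0.0, 0.0, -2.0],
    [0.0, 0.0, -3.0, 0.0],
    [0.0, -3.0, 0.0, -1.0],
    [-2.0, 0.0, -1.0, 0.0],
])


def pair_prob(logits, i, j, min_hairpin=3, temperature=1.0):
    """Probability of pair (i, j) under a masked softmax over expected energies."""
    seq = jax.nn.softmax(logits, axis=-1)  # soft sequence, shape (length, 4)
    energy = seq @ TABLE @ seq.T           # expected pair energies
    idx = jnp.arange(seq.shape[0])
    valid = (idx[None, :] - idx[:, None]) > min_hairpin
    scores = jnp.where(valid, -energy / temperature, -1e9)
    return jnp.exp(scores[i, j] - logsumexp(scores))


logits = jnp.zeros((12, 4))  # uniform distribution at every position
grad = jax.grad(pair_prob)(logits, 0, 11)  # shape (12, 4)
```

Starting from uniform logits, the largest gradient entry at position 0 is the one for G, since G has the most favorable expected energy against a uniform partner in this table.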

Input Specifications¤

| Key | Shape | Type | Description |
|---|---|---|---|
| sequence | (length, 4) or (batch, length, 4) | float32 | One-hot encoded RNA (A=0, C=1, G=2, U=3) |

Output Specifications¤

| Key | Shape | Type | Description |
|---|---|---|---|
| sequence | same as input | float32 | Original input sequence |
| bp_probs | (length, length) or (batch, length, length) | float32 | Base pair probability matrix |
| partition_function | () or (batch,) | float32 | Log partition function |

Base Pair Energies¤

| Pair | Default Energy | Description |
|---|---|---|
| A-U | -2.0 | Watson-Crick (2 H-bonds) |
| G-C | -3.0 | Watson-Crick (3 H-bonds) |
| G-U | -1.0 | Wobble pair |