Protein Structure Operators API¤
Differentiable operators for protein structure analysis.
DifferentiableSecondaryStructure¤
diffbio.operators.protein.secondary_structure.DifferentiableSecondaryStructure¤
DifferentiableSecondaryStructure(
    config: SecondaryStructureConfig,
    *,
    rngs: Rngs,
    name: str | None = None,
)
Bases: OperatorModule
Differentiable secondary structure prediction using the DSSP algorithm.
This operator computes secondary structure assignments for protein backbone atoms using a differentiable version of the DSSP algorithm. The key innovation is a continuous hydrogen bond matrix that enables gradient flow through the secondary structure prediction.
The algorithm:
1. Compute hydrogen bond energies using the Kabsch-Sander electrostatic formula.
2. Apply a smooth transformation to create a continuous H-bond matrix in [0, 1].
3. Detect helix patterns (i→i+4 H-bonds) and strand patterns.
4. Output soft secondary structure assignments.
Input data structure
- coordinates: Float[Array, "batch length 4 3"] - Backbone atoms (N, CA, C, O)
Output data structure (adds):
- ss_onehot: Float[Array, "batch length 3"] - Soft SS probabilities
- hbond_map: Float[Array, "batch length length"] - Continuous H-bond matrix
- ss_indices: Int[Array, "batch length"] - Hard SS assignments (0=loop, 1=helix, 2=strand)
Example
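import jax
from flax import nnx
from diffbio.operators.protein import DifferentiableSecondaryStructure, SecondaryStructureConfig
config = SecondaryStructureConfig(margin=1.0, cutoff=-0.5)
predictor = DifferentiableSecondaryStructure(config, rngs=nnx.Rngs(42))
key = jax.random.PRNGKey(0)  # PRNG key for the random test coordinates
coords = jax.random.uniform(key, (1, 50, 4, 3)) * 10  # 50 residues
result, _, _ = predictor.apply({"coordinates": coords}, {}, None)
ss_probs = result["ss_onehot"]  # (1, 50, 3)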
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `SecondaryStructureConfig` | Configuration with DSSP parameters. | required |
| `rngs` | `Rngs` | Random number generators. | required |
| `name` | `str \| None` | Optional name for the operator. | `None` |
apply¤
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]
Apply secondary structure prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Input data containing `coordinates`: Float[Array, "batch length 4 3"]. | required |
| `state` | `dict[str, Any]` | Element state (passed through). | required |
| `metadata` | `dict[str, Any] \| None` | Element metadata (passed through). | required |
| `random_params` | `Any` | Random parameters (unused). | `None` |
| `stats` | `dict[str, Any] \| None` | Optional statistics (unused). | `None` |

Returns:
| Type | Description |
|---|---|
| `tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None]` | Tuple of (output_data, state, metadata). |
compute_hbond_energy¤
compute_hbond_energy(
    coords: Float[Array, "batch length 4 3"],
) -> Float[Array, "batch length length"]
Compute hydrogen bond energy matrix.
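Uses the Kabsch-Sander electrostatic energy formula:
E = q1*q2 * f * (1/r_ON + 1/r_CH - 1/r_OH - 1/r_CN)
where r_XY is the distance between atoms X and Y, q1*q2 is the partial charge product (CONST_Q1Q2 = 0.084), and f is the conversion factor to kcal/mol (CONST_F = 332.0).
Donor: N-H from residue i.
Acceptor: C=O from residue j.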
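As a rough illustration of the formula above, the following JAX sketch evaluates the energy for a single donor/acceptor pair. The function name `kabsch_sander_energy` and the idealized collinear geometry are assumptions made for this example and are not part of the diffbio API; the two constants are the values listed in the Constants table.

```python
import jax.numpy as jnp

Q1Q2 = 0.084  # partial charge product (CONST_Q1Q2)
F = 332.0     # conversion factor to kcal/mol (CONST_F)


def kabsch_sander_energy(n, h, c, o):
    """Energy between a donor N-H (n, h) and an acceptor C=O (c, o), in kcal/mol."""

    def dist(a, b):
        return jnp.linalg.norm(a - b, axis=-1)

    return Q1Q2 * F * (
        1.0 / dist(o, n) + 1.0 / dist(c, h) - 1.0 / dist(o, h) - 1.0 / dist(c, n)
    )


# Hypothetical collinear arrangement C=O ... H-N along the x axis (Angstroms).
c = jnp.array([0.0, 0.0, 0.0])
o = jnp.array([1.23, 0.0, 0.0])
h = jnp.array([3.13, 0.0, 0.0])
n = jnp.array([4.14, 0.0, 0.0])

energy = kabsch_sander_energy(n, h, c, o)  # roughly -2.9 kcal/mol
```

For this geometry the energy lies well below the default cutoff of -0.5 kcal/mol, so the pair would count as hydrogen bonded.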
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `coords` | `Float[Array, 'batch length 4 3']` | Backbone coordinates (batch, length, 4, 3). Atom order: N, CA, C, O. | required |

Returns:
| Type | Description |
|---|---|
| `Float[Array, 'batch length length']` | Energy matrix (batch, length, length) in kcal/mol. E[b, i, j] = energy of H-bond from donor i to acceptor j. |
compute_hbond_map¤
compute_hbond_map(
    coords: Float[Array, "batch length 4 3"],
) -> Float[Array, "batch length length"]
Compute continuous hydrogen bond matrix.
Transforms the energy matrix into a continuous [0, 1] matrix using a smooth, sigmoid-like function based on sine:
HbondMat(i, j) = (1 + sin((cutoff - E - margin) / margin * pi/2)) / 2
This allows gradients to flow through the H-bond detection.
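The sine-based transform can be sketched in a few lines of JAX. This is a minimal sketch, not the library's implementation: the function name `soft_hbond` is hypothetical, and clipping the sine argument to [-margin, margin] is an assumption (the formula above does not state it, but without it the sine would not stay monotonic in E).

```python
import jax
import jax.numpy as jnp


def soft_hbond(energy, cutoff=-0.5, margin=1.0):
    """Map H-bond energies (kcal/mol) to a continuous score in [0, 1]."""
    # Assumption: clip so the sine stays on its monotonic branch.
    x = jnp.clip(cutoff - energy - margin, -margin, margin)
    return (1.0 + jnp.sin(x / margin * jnp.pi / 2)) / 2.0


energies = jnp.array([0.0, -0.5, -1.5, -2.5, -4.0])
scores = soft_hbond(energies)  # approximately [0, 0, 0.5, 1, 1]

# The score has a nonzero derivative inside the transition region.
slope = jax.grad(soft_hbond)(-1.5)  # approximately -pi/4
```

With the default cutoff and margin, the score under these assumptions rises from 0 at E = -0.5 to 1 at E = -2.5, and the gradient vanishes outside that window.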
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `coords` | `Float[Array, 'batch length 4 3']` | Backbone coordinates (batch, length, 4, 3). | required |

Returns:
| Type | Description |
|---|---|
| `Float[Array, 'batch length length']` | Continuous H-bond matrix (batch, length, length) in [0, 1]. |
detect_helix_pattern¤
detect_helix_pattern(
    hbond_map: Float[Array, "batch length length"],
) -> Float[Array, "batch length"]
Detect alpha-helix pattern (i→i+4 hydrogen bonds).
Alpha-helices are characterized by H-bonds in which the N-H of residue i (donor) bonds to the C=O of residue i-4 (acceptor). Viewed from the acceptor side, this is the i→i+4 backbone H-bond pattern.
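With the index convention hbond_map[b, i, j] = bond from donor i to acceptor j, the helix-forming bonds lie on the fourth sub-diagonal of the matrix. The sketch below only extracts that diagonal; the name `turn_scores` is hypothetical, and the actual method may combine consecutive turns or apply `min_helix_length` in ways not shown here.

```python
import jax.numpy as jnp


def turn_scores(hbond_map, n=4):
    """Per-residue score of the donor i -> acceptor i-n hydrogen bond."""
    # offset=-n selects the entries hbond_map[b, i, i - n] for i >= n.
    turns = jnp.diagonal(hbond_map, offset=-n, axis1=1, axis2=2)
    # Left-pad with zeros so that position i refers to donor residue i.
    return jnp.pad(turns, ((0, 0), (n, 0)))


# Toy map with one bond: residue 5 donates to residue 1.
hbond_map = jnp.zeros((1, 8, 8)).at[0, 5, 1].set(1.0)
scores = turn_scores(hbond_map)  # shape (1, 8), nonzero only at residue 5
```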
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hbond_map` | `Float[Array, 'batch length length']` | Continuous H-bond matrix. | required |

Returns:
| Type | Description |
|---|---|
| `Float[Array, 'batch length']` | Soft helix assignment for each residue. |
detect_strand_pattern¤
detect_strand_pattern(
    hbond_map: Float[Array, "batch length length"],
) -> Float[Array, "batch length"]
Detect beta-strand pattern (parallel/antiparallel H-bonds).
Beta-strands are characterized by H-bonds between distant residues forming ladder-like patterns (parallel or antiparallel).
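To give a feel for how a ladder-like pattern can be read off a continuous H-bond matrix, the sketch below scores only the simplest case: a pair of distant residues that are mutually hydrogen bonded, as in one of the antiparallel bridge patterns. It is a deliberately reduced illustration, not the method's actual logic; the name `mutual_bridge_scores` and the `min_separation` parameter are hypothetical, and parallel bridges are not covered.

```python
import jax.numpy as jnp


def mutual_bridge_scores(hbond_map, min_separation=3):
    """Per-residue score for mutual H-bonds (i -> j and j -> i) with |i - j| >= min_separation."""
    length = hbond_map.shape[-1]
    idx = jnp.arange(length)
    far_apart = jnp.abs(idx[:, None] - idx[None, :]) >= min_separation
    # Soft AND of the two bond directions.
    mutual = hbond_map * jnp.swapaxes(hbond_map, 1, 2)
    mutual = jnp.where(far_apart[None, :, :], mutual, 0.0)
    # Best partner for each residue.
    return mutual.max(axis=-1)


hbond_map = jnp.zeros((1, 10, 10))
hbond_map = hbond_map.at[0, 2, 8].set(1.0).at[0, 8, 2].set(1.0)  # mutual, distant
hbond_map = hbond_map.at[0, 1, 7].set(1.0)                       # one direction only
hbond_map = hbond_map.at[0, 4, 5].set(1.0).at[0, 5, 4].set(1.0)  # mutual, but adjacent
scores = mutual_bridge_scores(hbond_map)  # nonzero only at residues 2 and 8
```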
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hbond_map` | `Float[Array, 'batch length length']` | Continuous H-bond matrix. | required |

Returns:
| Type | Description |
|---|---|
| `Float[Array, 'batch length']` | Soft strand assignment for each residue. |
assign_secondary_structure¤
assign_secondary_structure(
    hbond_map: Float[Array, "batch length length"],
) -> Float[Array, "batch length 3"]
Assign secondary structure based on H-bond patterns.
Combines helix and strand detection into soft assignments.
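One plausible way to combine per-residue helix and strand scores into three-class soft assignments is a temperature-scaled softmax, sketched below. The function name `soft_assignment` and the loop score (one minus the larger of the two other scores) are assumptions for illustration; the source only states that a softmax temperature is configurable.

```python
import jax
import jax.numpy as jnp


def soft_assignment(helix, strand, temperature=1.0):
    """Stack loop/helix/strand scores and normalize with a softmax."""
    # Assumption: a residue is loop-like when neither pattern scores highly.
    loop = 1.0 - jnp.maximum(helix, strand)
    logits = jnp.stack([loop, helix, strand], axis=-1) / temperature
    return jax.nn.softmax(logits, axis=-1)


helix = jnp.array([[0.9, 0.1, 0.0]])   # (batch, length)
strand = jnp.array([[0.0, 0.8, 0.1]])

probs = soft_assignment(helix, strand, temperature=1.0)  # (1, 3, 3)
sharp = soft_assignment(helix, strand, temperature=0.1)  # closer to one-hot
hard = jnp.argmax(probs, axis=-1)                        # [[1, 2, 0]]
```

Lower temperatures push the probabilities toward a hard one-hot assignment while keeping the output differentiable.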
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hbond_map` | `Float[Array, 'batch length length']` | Continuous H-bond matrix. | required |

Returns:
| Type | Description |
|---|---|
| `Float[Array, 'batch length 3']` | One-hot encoded SS assignments (batch, length, 3). Classes: 0=loop, 1=helix, 2=strand. |
SecondaryStructureConfig¤
diffbio.operators.protein.secondary_structure.SecondaryStructureConfig dataclass¤
SecondaryStructureConfig(
    use_bond_length_constraint: bool = False,
    use_bond_angle_constraint: bool = False,
    bond_length_weight: float = 0.1,
    bond_angle_weight: float = 0.1,
    margin: float = DEFAULT_MARGIN,
    cutoff: float = DEFAULT_CUTOFF,
    min_helix_length: int = 4,
    temperature: float = 1.0,
)
Factory Function¤
create_secondary_structure_predictor¤
diffbio.operators.protein.secondary_structure.create_secondary_structure_predictor¤
create_secondary_structure_predictor(
    margin: float = DEFAULT_MARGIN,
    cutoff: float = DEFAULT_CUTOFF,
    min_helix_length: int = 4,
    temperature: float = 1.0,
    seed: int = 42,
) -> DifferentiableSecondaryStructure
Factory function to create a secondary structure predictor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `margin` | `float` | Smoothing margin for H-bond matrix. Default 1.0. | `DEFAULT_MARGIN` |
| `cutoff` | `float` | H-bond energy threshold in kcal/mol. Default -0.5. | `DEFAULT_CUTOFF` |
| `min_helix_length` | `int` | Minimum residues for helix. Default 4. | `4` |
| `temperature` | `float` | Softmax temperature. Default 1.0. | `1.0` |
| `seed` | `int` | Random seed. Default 42. | `42` |

Returns:
| Type | Description |
|---|---|
| `DifferentiableSecondaryStructure` | Configured DifferentiableSecondaryStructure instance. |
Helper Functions¤
compute_hydrogen_position¤
diffbio.operators.protein.secondary_structure.compute_hydrogen_position¤
compute_hydrogen_position(
    n_pos: Float[Array, "... 3"],
    ca_pos: Float[Array, "... 3"],
    c_prev_pos: Float[Array, "... 3"],
) -> Float[Array, "... 3"]
Compute hydrogen atom position from backbone atoms.
The amide hydrogen is placed along the N-H bond direction, which is approximately opposite to the bisector of CA-N and C_prev-N vectors.
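The geometric construction can be sketched as follows. This is an illustration under stated assumptions rather than the library's code: the name `place_amide_hydrogen` is hypothetical, and the N-H bond length of 1.01 Å is an assumed typical value that the source does not specify.

```python
import jax.numpy as jnp

NH_BOND_LENGTH = 1.01  # Angstroms; assumed typical amide N-H bond length


def place_amide_hydrogen(n_pos, ca_pos, c_prev_pos):
    """Place H opposite the bisector of the N->CA and N->C_prev directions."""

    def unit(v):
        return v / jnp.linalg.norm(v, axis=-1, keepdims=True)

    bisector = unit(ca_pos - n_pos) + unit(c_prev_pos - n_pos)
    return n_pos - NH_BOND_LENGTH * unit(bisector)


# Hypothetical planar geometry with a 120 degree CA-N-C_prev angle (Angstroms).
n_pos = jnp.array([0.0, 0.0, 0.0])
ca_pos = jnp.array([1.46, 0.0, 0.0])
c_prev_pos = jnp.array([-0.665, 1.152, 0.0])

h_pos = place_amide_hydrogen(n_pos, ca_pos, c_prev_pos)
```

Because the function only uses the last axis for the vector components, it broadcasts over any leading batch or residue dimensions, matching the `"... 3"` shapes in the signature.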
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `n_pos` | `Float[Array, '... 3']` | Nitrogen atom positions. | required |
| `ca_pos` | `Float[Array, '... 3']` | Alpha carbon positions. | required |
| `c_prev_pos` | `Float[Array, '... 3']` | Carbonyl carbon from previous residue. | required |

Returns:
| Type | Description |
|---|---|
| `Float[Array, '... 3']` | Estimated hydrogen atom positions. |
Usage Examples¤
Basic Usage¤
from diffbio.operators.protein import create_secondary_structure_predictor
import jax
import jax.numpy as jnp
# Create predictor
predictor = create_secondary_structure_predictor()
# Prepare coordinates (batch, n_residues, 4 atoms, xyz)
coords = jax.random.uniform(jax.random.PRNGKey(0), (1, 50, 4, 3)) * 10
# Apply
result, _, _ = predictor.apply({"coordinates": coords}, {}, None)
ss_probs = result["ss_onehot"] # (1, 50, 3)
Full Configuration¤
from diffbio.operators.protein import (
    DifferentiableSecondaryStructure,
    SecondaryStructureConfig,
)
from flax import nnx
config = SecondaryStructureConfig(
    margin=1.0,
    cutoff=-0.5,
    min_helix_length=4,
    temperature=0.5,  # Sharper assignments
)
predictor = DifferentiableSecondaryStructure(config, rngs=nnx.Rngs(42))
Gradient Computation¤
import jax
# Reuses `predictor` and `coords` from the examples above.
def loss_fn(coords):
    result, _, _ = predictor.apply({"coordinates": coords}, {}, None)
    # Maximize helix content
    helix_prob = result["ss_onehot"][:, :, 1]  # Index 1 = helix
    return -helix_prob.mean()
# Compute gradients w.r.t. coordinates
grads = jax.grad(loss_fn)(coords)
Input Specifications¤
| Key | Shape | Description |
|---|---|---|
| `coordinates` | (batch, length, 4, 3) | Backbone atoms (N, CA, C, O) in Angstroms |
Output Specifications¤
| Key | Shape | Description |
|---|---|---|
| `coordinates` | (batch, length, 4, 3) | Original input coordinates |
| `ss_onehot` | (batch, length, 3) | Soft SS probabilities |
| `ss_indices` | (batch, length) | Hard SS assignments |
| `hbond_map` | (batch, length, length) | Continuous H-bond matrix |
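For inspection or logging, the hard assignments in `ss_indices` can be turned into DSSP-style strings. The helper below is a hypothetical convenience, not part of the documented API; it assumes the class order given earlier (0=loop, 1=helix, 2=strand) and uses the common letters `-`, `H` and `E`.

```python
import jax.numpy as jnp

SS_LETTERS = "-HE"  # index 0=loop, 1=helix, 2=strand


def to_ss_strings(ss_indices):
    """Convert a (batch, length) integer array into one string per batch item."""
    return ["".join(SS_LETTERS[int(i)] for i in row) for row in ss_indices.tolist()]


# Stand-in for result["ss_indices"] from the operator output.
ss_indices = jnp.array([[0, 1, 1, 1, 1, 0, 2, 2, 0]])
strings = to_ss_strings(ss_indices)  # ["-HHHH-EE-"]
```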
Constants¤
| Constant | Value | Description |
|---|---|---|
| `CONST_Q1Q2` | 0.084 | Partial charge product |
| `CONST_F` | 332.0 | Conversion to kcal/mol |
| `DEFAULT_CUTOFF` | -0.5 | H-bond energy threshold (kcal/mol) |
| `DEFAULT_MARGIN` | 1.0 | Smoothing margin |