
Alignment Operators API

Advanced differentiable alignment operators for multiple sequence alignment and profile HMMs.

SoftProgressiveMSA

diffbio.operators.alignment.soft_msa.SoftProgressiveMSA

SoftProgressiveMSA(
    config: SoftProgressiveMSAConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Differentiable progressive multiple sequence alignment.

This operator performs multiple sequence alignment using soft operations that maintain gradient flow throughout the process.

Algorithm:

  1. Encode each sequence to get embeddings.
  2. Compute pairwise distances from the embeddings.
  3. Build a soft guide tree from the distances.
  4. Progressively align the sequences following the guide tree order.
  5. Build the consensus profile.
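Steps 2 and 3 can be illustrated with a small standalone sketch. It assumes each sequence has already been reduced to a single embedding vector (for example by mean-pooling the encoder output) and keeps only the first decision of the guide tree: a softmin over all off-diagonal pairs, which relaxes the hard "merge the closest pair first" rule of UPGMA-style guide trees. `pairwise_distances` and `soft_first_merge` are hypothetical names, not diffbio functions, and this is not how SoftProgressiveMSA is necessarily implemented.

```python
import jax
import jax.numpy as jnp


def pairwise_distances(embeddings: jax.Array) -> jax.Array:
    """Euclidean distances between per-sequence embeddings, shape (n, n)."""
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    # The small epsilon keeps the sqrt differentiable on the zero diagonal.
    return jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)


def soft_first_merge(distances: jax.Array, temperature: float = 1.0) -> jax.Array:
    """Probability of merging each (i, j) pair first: a softmin over off-diagonal pairs."""
    n = distances.shape[0]
    # Self-pairs get a huge distance so they receive (numerically) zero weight.
    masked = jnp.where(jnp.eye(n, dtype=bool), 1e9, distances)
    weights = jax.nn.softmax(-masked.reshape(-1) / temperature)
    return weights.reshape(n, n)


# Hypothetical embeddings: sequences 0 and 1 are close, sequence 2 is far away.
emb = jnp.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
merge = soft_first_merge(pairwise_distances(emb), temperature=0.1)
```

The distance matrix is symmetric, so each unordered pair appears twice; here almost all of the weight lands on (0, 1) and (1, 0). Lowering `temperature` pushes the weights toward the hard argmin while keeping the step differentiable.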

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection
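The helpers above are only named here. A minimal standalone sketch of what a logsumexp-based smooth maximum and a softmax-weighted soft argmax usually compute is shown below; the function signatures are assumptions and may not match the actual methods on TemperatureOperator.

```python
import jax
import jax.numpy as jnp


def soft_max(x: jax.Array, temperature: float = 1.0) -> jax.Array:
    """Smooth maximum over the last axis: T * logsumexp(x / T)."""
    return temperature * jax.nn.logsumexp(x / temperature, axis=-1)


def soft_argmax(x: jax.Array, temperature: float = 1.0) -> jax.Array:
    """Expected index over the last axis under softmax(x / T)."""
    weights = jax.nn.softmax(x / temperature, axis=-1)
    positions = jnp.arange(x.shape[-1], dtype=weights.dtype)
    return jnp.sum(weights * positions, axis=-1)


scores = jnp.array([1.0, 3.0, 2.0])
sharp = soft_max(scores, temperature=0.01)     # close to the hard max, 3.0
smooth = soft_max(scores, temperature=1.0)     # above 3.0: logsumexp upper-bounds max
where = soft_argmax(scores, temperature=0.01)  # close to index 1
```

As the temperature goes to zero both functions approach the hard max and argmax. At higher temperatures the gradient of `soft_max`, which is `softmax(x / T)`, is spread over all entries instead of only the largest one.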

Parameters:

  • config (SoftProgressiveMSAConfig, required): SoftProgressiveMSAConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators, used for parameter initialization.
  • name (str | None, default None): Optional operator name.

Example:
config = SoftProgressiveMSAConfig(max_seq_length=100)
msa = SoftProgressiveMSA(config, rngs=nnx.Rngs(42))
data = {"sequences": seqs}  # (n_seqs, seq_len, alphabet_size)
result, state, meta = msa.apply(data, {}, None)

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply soft progressive MSA.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "sequences": input sequences of shape (n_seqs, seq_len, alphabet_size)
  • state (PyTree, required): Element state, passed through unchanged.
  • metadata (dict[str, Any] | None, required): Element metadata, passed through unchanged.
  • random_params (Any, default None): Not used.
  • stats (dict[str, Any] | None, default None): Not used.
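The "sequences" entry is a dense one-hot tensor of shape (n_seqs, seq_len, alphabet_size). A minimal way to build it from DNA strings is sketched below. The ACGT channel order and the zero-padding of shorter sequences are assumptions; check which encoding and padding the operator expects before relying on them. `one_hot_batch` is a hypothetical helper, not a diffbio function.

```python
import jax
import jax.numpy as jnp
import numpy as np

DNA = "ACGT"  # assumed channel order


def one_hot_batch(seqs: list[str], max_len: int, alphabet: str = DNA) -> jax.Array:
    """Encode strings as one-hot vectors of shape (n_seqs, max_len, len(alphabet))."""
    index = {ch: i for i, ch in enumerate(alphabet)}
    # -1 marks padding and unknown characters; jax.nn.one_hot maps it to an all-zero row.
    ids = np.full((len(seqs), max_len), -1, dtype=np.int32)
    for row, seq in enumerate(seqs):
        for col, ch in enumerate(seq.upper()[:max_len]):
            ids[row, col] = index.get(ch, -1)
    return jax.nn.one_hot(jnp.asarray(ids), len(alphabet))


sequences = one_hot_batch(["ACGTAC", "ACGAC", "AGTAC"], max_len=8)
data = {"sequences": sequences}  # shape (3, 8, 4)
```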

Returns:

tuple[PyTree, PyTree, dict[str, Any] | None]: a tuple (transformed_data, state, metadata), where:

  • transformed_data contains:
      - "sequences": the original input sequences
      - "aligned_sequences": the soft-aligned sequences
      - "pairwise_distances": pairwise distances used to build the guide tree
      - "alignment_scores": pairwise alignment scores
      - "consensus_profile": the consensus profile
  • state is passed through unchanged
  • metadata is passed through unchanged
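The return description does not say how "consensus_profile" is derived from "aligned_sequences". A common definition is the per-column residue frequency. The sketch below assumes the aligned tensor has shape (n_seqs, aln_len, alphabet_size), with one (soft) one-hot vector per position, and uses a hypothetical standalone function `consensus_profile`.

```python
import jax
import jax.numpy as jnp


def consensus_profile(aligned: jax.Array, eps: float = 1e-8) -> jax.Array:
    """Per-column residue frequencies, shape (aln_len, alphabet_size)."""
    counts = aligned.sum(axis=0)  # soft counts per column and residue
    # Normalize each column; eps guards against all-zero (all-gap) columns.
    return counts / (counts.sum(axis=-1, keepdims=True) + eps)


# Three aligned DNA sequences: AAA... columns read "AAA", "CCG", "GTG".
aligned = jax.nn.one_hot(jnp.array([[0, 1, 2], [0, 1, 3], [0, 2, 2]]), 4)
profile = consensus_profile(aligned)
consensus = "".join("ACGT"[i] for i in profile.argmax(axis=-1).tolist())
```

Because the profile is only a sum and a division, gradients reach every soft-aligned position, which fits the operator's goal of keeping the whole pipeline differentiable.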

SoftProgressiveMSAConfig

diffbio.operators.alignment.soft_msa.SoftProgressiveMSAConfig (dataclass)

SoftProgressiveMSAConfig(
    max_seq_length: int = 100,
    hidden_dim: int = 64,
    num_layers: int = 2,
    alphabet_size: int = 4,
    temperature: float = 1.0,
    gap_open_penalty: float = -10.0,
    gap_extend_penalty: float = -1.0,
)

Bases: OperatorConfig

Configuration for SoftProgressiveMSA.

Attributes:

  • max_seq_length (int): Maximum sequence length.
  • hidden_dim (int): Hidden dimension for neural networks.
  • num_layers (int): Number of encoder layers.
  • alphabet_size (int): Size of the sequence alphabet (4 for DNA, 20 for protein).
  • temperature (float): Temperature for softmax operations.
  • gap_open_penalty (float): Gap opening penalty.
  • gap_extend_penalty (float): Gap extension penalty.
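Both gap penalties default to negative values, so they act as scores added to an alignment. Under the usual affine gap model, a gap of length k pays one opening penalty plus (k - 1) extension penalties. The sketch below assumes that convention (some tools instead charge the opening penalty plus k extensions) and does not claim this is how SoftProgressiveMSA applies the penalties internally; `affine_gap_score` is a hypothetical helper.

```python
def affine_gap_score(length: int, gap_open: float = -10.0, gap_extend: float = -1.0) -> float:
    """Score contributed by a single gap of `length` positions (0 for no gap)."""
    if length <= 0:
        return 0.0
    # The first gap position pays the opening penalty, each further one the extension.
    return gap_open + (length - 1) * gap_extend


one_long_gap = affine_gap_score(4)        # -10 + 3 * -1
two_short_gaps = 2 * affine_gap_score(2)  # 2 * (-10 + 1 * -1)
```

With the default values, one gap of four positions (-13.0) scores better than two gaps of two positions (-22.0), which is why affine penalties favor fewer, longer gaps.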

ProfileHMMSearch

diffbio.operators.alignment.profile_hmm.ProfileHMMSearch

ProfileHMMSearch(
    config: ProfileHMMConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: TemperatureOperator

Profile HMM search with differentiable scoring.

This operator implements a simplified profile HMM with match, insert, and delete states. The forward algorithm computes the alignment score differentiably using logsumexp.

Profile HMM structure (per position):

  • Match state: emits according to a position-specific distribution
  • Insert state: emits according to the background distribution
  • Delete state: silent (no emission)

Transitions:

  • From match: M->M, M->I, M->D
  • From insert: I->M, I->I
  • From delete: D->M, D->D
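The forward algorithm over these states can be sketched in log space with `jax.lax.scan`. This is a minimal illustration, not diffbio's implementation. The transition layout (one row of seven logits per profile position), a single insert emission distribution shared by all insert states, treating the Begin state as M_0, and ending in M_K, I_K or D_K without an explicit End transition are all assumptions made for the sketch. Every combination step is a `logaddexp`, so the score is differentiable with respect to all parameters.

```python
import jax
import jax.numpy as jnp
from jax.nn import log_softmax

NEG_INF = -1e30  # finite stand-in for log(0) that keeps gradients free of NaNs


def profile_hmm_forward(seq, match_logits, insert_logits, trans_logits):
    """Log-space forward score of a one-hot sequence under a simplified profile HMM.

    seq:           (L, A) one-hot residues
    match_logits:  (K, A) emission logits of match states M_1..M_K
    insert_logits: (A,) background emission logits shared by all insert states
    trans_logits:  (K + 1, 7) logits for [MM, MI, MD, IM, II, DM, DD]; row j holds
                   the moves out of M_j, I_j and D_j (M_0 acts as the Begin state)
    """
    K = match_logits.shape[0]
    e_m = seq @ log_softmax(match_logits, axis=-1).T  # (L, K) log p(x_i | M_j)
    e_i = seq @ log_softmax(insert_logits)             # (L,) log p(x_i | background)
    t_m = log_softmax(trans_logits[:, 0:3], axis=-1)   # out of M: MM, MI, MD
    t_i = log_softmax(trans_logits[:, 3:5], axis=-1)   # out of I: IM, II
    t_d = log_softmax(trans_logits[:, 5:7], axis=-1)   # out of D: DM, DD
    neg = jnp.full((1,), NEG_INF, dtype=t_m.dtype)

    def delete_row(f_m):
        # D_j is silent: it is entered from M_{j-1} or D_{j-1} within the same row.
        def step(prev_d, j):
            d = jnp.logaddexp(f_m[j - 1] + t_m[j - 1, 2], prev_d + t_d[j - 1, 1])
            return d, d

        start = jnp.array(NEG_INF, dtype=f_m.dtype)
        _, ds = jax.lax.scan(step, start, jnp.arange(1, K + 1))
        return jnp.concatenate([neg, ds])

    def row(carry, emissions):
        p_m, p_i, p_d = carry
        em, ei = emissions
        # M_j emits x_i after leaving M_{j-1}, I_{j-1} or D_{j-1} in the previous row.
        into_m = jnp.logaddexp(
            jnp.logaddexp(p_m[:-1] + t_m[:-1, 0], p_i[:-1] + t_i[:-1, 0]),
            p_d[:-1] + t_d[:-1, 0],
        )
        f_m = jnp.concatenate([neg, em + into_m])
        # I_j emits x_i after M_j (M->I) or I_j (I->I); there is no D->I move.
        f_i = ei + jnp.logaddexp(p_m + t_m[:, 1], p_i + t_i[:, 1])
        return (f_m, f_i, delete_row(f_m)), None

    f_m0 = jnp.full((K + 1,), NEG_INF, dtype=t_m.dtype).at[0].set(0.0)  # Begin = M_0
    f_i0 = jnp.full((K + 1,), NEG_INF, dtype=t_m.dtype)
    carry = (f_m0, f_i0, delete_row(f_m0))
    (f_m, f_i, f_d), _ = jax.lax.scan(row, carry, (e_m, e_i))
    # Simplification: paths may end in M_K, I_K or D_K with no explicit End transition.
    return jnp.logaddexp(jnp.logaddexp(f_m[K], f_i[K]), f_d[K])


# Toy profile of length 3 that prefers A, C, G (channel order ACGT assumed).
K, A = 3, 4
match_logits = 5.0 * jax.nn.one_hot(jnp.array([0, 1, 2]), A)
insert_logits = jnp.zeros(A)
trans_logits = jnp.zeros((K + 1, 7))  # uniform transitions
acg = jax.nn.one_hot(jnp.array([0, 1, 2]), A)
ttt = jax.nn.one_hot(jnp.array([3, 3, 3]), A)
s_acg = profile_hmm_forward(acg, match_logits, insert_logits, trans_logits)
s_ttt = profile_hmm_forward(ttt, match_logits, insert_logits, trans_logits)
grads = jax.grad(profile_hmm_forward, argnums=1)(acg, match_logits, insert_logits, trans_logits)
```

The sequence that matches the preferred residues scores higher than one that does not, and `jax.grad` gives finite gradients for the match emission logits, which is what allows a profile to be learned.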

Inherits from TemperatureOperator to get:

  • _temperature property for temperature-controlled smoothing
  • soft_max() for logsumexp-based smooth maximum
  • soft_argmax() for soft position selection

Parameters:

  • config (ProfileHMMConfig, required): ProfileHMMConfig with model parameters.
  • rngs (Rngs | None, default None): Flax NNX random number generators, used for parameter initialization.
  • name (str | None, default None): Optional operator name.

Example:
config = ProfileHMMConfig(profile_length=100, alphabet_size=20)
profiler = ProfileHMMSearch(config, rngs=nnx.Rngs(42))
data = {"sequence": one_hot_sequence}
result, state, meta = profiler.apply(data, {}, None)

apply

apply(
    data: PyTree,
    state: PyTree,
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply profile HMM search to sequence.

Parameters:

  • data (PyTree, required): Dictionary containing:
      - "sequence": one-hot encoded sequence of shape (seq_len, alphabet_size)
  • state (PyTree, required): Element state, passed through unchanged.
  • metadata (dict[str, Any] | None, required): Element metadata, passed through unchanged.
  • random_params (Any, default None): Not used (deterministic operator).
  • stats (dict[str, Any] | None, default None): Not used.

Returns:

tuple[PyTree, PyTree, dict[str, Any] | None]: a tuple (transformed_data, state, metadata), where:

  • transformed_data contains:
      - "sequence": the original sequence
      - "score": the profile alignment score
      - "state_posteriors": soft state assignments
  • state is passed through unchanged
  • metadata is passed through unchanged
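The documentation does not say how "state_posteriors" is computed. One standard way to get soft state assignments from any forward algorithm is to differentiate the log-partition with respect to the per-position emission log-scores: the derivative for (position t, state s) equals the posterior probability of being in state s at t. The sketch below shows this on a plain two-state HMM rather than the full profile HMM; `hmm_log_partition` and `state_posteriors` are hypothetical helpers, and the operator may compute its posteriors differently (for example with an explicit forward-backward pass).

```python
import jax
import jax.numpy as jnp


def hmm_log_partition(log_emit, log_trans, log_init):
    """Forward algorithm: log of the total score of all state paths.

    log_emit:  (L, S) log emission score of each position under each state
    log_trans: (S, S) log transition scores, rows = from-state, columns = to-state
    log_init:  (S,) log initial-state scores
    """

    def step(alpha, emit):
        # Sum over the previous state, then add this position's emission score.
        return jax.nn.logsumexp(alpha[:, None] + log_trans, axis=0) + emit, None

    alpha, _ = jax.lax.scan(step, log_init + log_emit[0], log_emit[1:])
    return jax.nn.logsumexp(alpha)


def state_posteriors(log_emit, log_trans, log_init):
    """P(state s at position t | sequence), shape (L, S), via d log Z / d log_emit."""
    return jax.grad(hmm_log_partition)(log_emit, log_trans, log_init)


log_trans = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
log_init = jnp.log(jnp.array([0.5, 0.5]))
log_emit = jnp.log(jnp.array([[0.9, 0.2], [0.8, 0.3], [0.1, 0.7]]))
post = state_posteriors(log_emit, log_trans, log_init)  # each row sums to 1
```

Reusing the gradient avoids writing a separate backward pass: the same code that scores the sequence also yields its soft state assignments.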

ProfileHMMConfig

diffbio.operators.alignment.profile_hmm.ProfileHMMConfig (dataclass)

ProfileHMMConfig(
    cacheable: bool = True,
    profile_length: int = 100,
    alphabet_size: int = 20,
    temperature: float = 1.0,
    learnable_profile: bool = True,
)

Bases: OperatorConfig

Configuration for ProfileHMMSearch.

Attributes:

Name Type Description
profile_length int

Length of the profile (number of match states).

alphabet_size int

Size of sequence alphabet (20 for protein, 4 for DNA).

temperature float

Temperature for softmax operations.

learnable_profile bool

Whether profile parameters are learnable.

Usage Examples

Multiple Sequence Alignment

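from flax import nnx
import jax
from diffbio.operators.alignment import SoftProgressiveMSA, SoftProgressiveMSAConfig

config = SoftProgressiveMSAConfig(
    max_seq_length=100,
    hidden_dim=64,
    alphabet_size=4,
)
msa = SoftProgressiveMSA(config, rngs=nnx.Rngs(42))

# Example input: 5 random DNA sequences of length 100, one-hot encoded.
tokens = jax.random.randint(jax.random.PRNGKey(0), (5, 100), 0, 4)
sequences = jax.nn.one_hot(tokens, 4)  # (n_seqs, seq_len, alphabet_size)

data = {"sequences": sequences}
result, _, _ = msa.apply(data, {}, None)
aligned = result["aligned_sequences"]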

Profile HMM Scoring

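from flax import nnx
import jax
from diffbio.operators.alignment import ProfileHMMSearch, ProfileHMMConfig

config = ProfileHMMConfig(profile_length=50, alphabet_size=4)
hmm = ProfileHMMSearch(config, rngs=nnx.Rngs(42))

# Example input: one random DNA sequence of length 60, one-hot encoded.
sequence = jax.nn.one_hot(jax.random.randint(jax.random.PRNGKey(0), (60,), 0, 4), 4)

data = {"sequence": sequence}  # (seq_len, alphabet_size)
result, _, _ = hmm.apply(data, {}, None)
score = result["score"]  # documented output key of ProfileHMMSearch.apply
posteriors = result["state_posteriors"]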