Alignment Operators API
Advanced differentiable alignment operators for multiple sequence alignment and profile HMMs.
SoftProgressiveMSA
diffbio.operators.alignment.soft_msa.SoftProgressiveMSA
SoftProgressiveMSA(
config: SoftProgressiveMSAConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Differentiable progressive multiple sequence alignment.
This operator performs multiple sequence alignment using soft operations that maintain gradient flow throughout the process.
Algorithm:
1. Encode each sequence to get embeddings.
2. Compute pairwise distances from the embeddings.
3. Build a soft guide tree from the distances.
4. Align progressively, following the guide tree order.
5. Build a consensus profile.
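Steps 1 to 3 above can be illustrated with a minimal JAX sketch. It is not the operator's implementation: the linear `projection` encoder, the mean pooling, and the helper name `soft_neighbor_weights` are assumptions made for illustration, and a row-wise softmax over negative distances stands in for the soft guide tree.

```python
import jax
import jax.numpy as jnp


def soft_neighbor_weights(sequences, projection, temperature=1.0):
    """Hypothetical helper: embeddings -> pairwise distances -> soft neighbor weights.

    sequences:  (n_seqs, seq_len, alphabet_size) one-hot or soft sequences
    projection: (alphabet_size, hidden_dim) stand-in for a learned encoder
    """
    # Step 1 (stand-in encoder): project each position, then mean-pool over positions.
    embeddings = jnp.mean(sequences @ projection, axis=1)  # (n_seqs, hidden_dim)

    # Step 2: squared Euclidean distances (squared form keeps gradients finite at 0).
    diffs = embeddings[:, None, :] - embeddings[None, :, :]
    distances = jnp.sum(diffs**2, axis=-1)  # (n_seqs, n_seqs)

    # Step 3 (soft stand-in for a guide tree): each row is a distribution over
    # the other sequences, with closer sequences receiving more weight.
    n = distances.shape[0]
    masked = jnp.where(jnp.eye(n, dtype=bool), 1e9, distances)  # exclude self
    weights = jax.nn.softmax(-masked / temperature, axis=-1)
    return distances, weights


key_seq, key_proj = jax.random.split(jax.random.PRNGKey(0))
tokens = jax.random.randint(key_seq, (3, 8), 0, 4)
sequences = jax.nn.one_hot(tokens, 4)               # (3, 8, 4)
projection = jax.random.normal(key_proj, (4, 16))   # illustrative encoder weights
distances, weights = soft_neighbor_weights(sequences, projection)
```

Lowering `temperature` concentrates each row of `weights` on the nearest sequence, which approaches a hard nearest-neighbor choice while keeping the computation differentiable.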
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
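For orientation, the two inherited helpers can be approximated as follows. This is a sketch of the usual logsumexp-based smooth maximum and softmax-weighted position, written as free functions with an explicit `temperature` argument. The actual method signatures on TemperatureOperator are not shown in this page and may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_max(x, temperature=1.0):
    """Smooth maximum over the last axis: temperature * logsumexp(x / temperature)."""
    return temperature * logsumexp(x / temperature, axis=-1)


def soft_argmax(x, temperature=1.0):
    """Expected index over the last axis under softmax(x / temperature)."""
    weights = jax.nn.softmax(x / temperature, axis=-1)
    positions = jnp.arange(x.shape[-1], dtype=x.dtype)
    return jnp.sum(weights * positions, axis=-1)


scores = jnp.array([1.0, 2.0, 5.0])
sharp_max = soft_max(scores, temperature=0.1)     # close to the hard maximum
sharp_pos = soft_argmax(scores, temperature=0.1)  # close to the hard argmax index
max_grad = jax.grad(soft_max)(scores)             # softmax weights, sums to 1
```

As the temperature approaches zero both functions approach their hard counterparts. Larger temperatures spread gradient across more positions at the cost of a looser approximation.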
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `SoftProgressiveMSAConfig` | SoftProgressiveMSAConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | `None` |
| name | `str \| None` | Optional operator name. | `None` |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply soft progressive MSA.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing `"sequences"`: input sequences of shape `(n_seqs, seq_len, alphabet_size)`. | required |
| state | `PyTree` | Element state (passed through unchanged). | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| random_params | `Any` | Not used. | `None` |
| stats | `dict[str, Any] \| None` | Not used. | `None` |
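The `"sequences"` entry expects an array of shape `(n_seqs, seq_len, alphabet_size)`. One way to build such an array from DNA strings is sketched below. The helper `one_hot_encode` and the `"ACGT"` symbol order are assumptions for illustration, since this page does not specify an alphabet ordering.

```python
import jax
import jax.numpy as jnp

DNA_ALPHABET = "ACGT"  # assumed symbol order, not taken from the library


def one_hot_encode(seqs, alphabet=DNA_ALPHABET):
    """Hypothetical helper: equal-length strings -> (n_seqs, seq_len, alphabet_size)."""
    index = {symbol: i for i, symbol in enumerate(alphabet)}
    tokens = jnp.array([[index[symbol] for symbol in seq] for seq in seqs])
    return jax.nn.one_hot(tokens, len(alphabet))


sequences = one_hot_encode(["ACGT", "ACGA", "TCGT"])  # shape (3, 4, 4)
data = {"sequences": sequences}
```

Sequences of different lengths would need padding to a common length first. How the operator treats padding is not covered here.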
Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
SoftProgressiveMSAConfig
diffbio.operators.alignment.soft_msa.SoftProgressiveMSAConfig (dataclass)
SoftProgressiveMSAConfig(
max_seq_length: int = 100,
hidden_dim: int = 64,
num_layers: int = 2,
alphabet_size: int = 4,
temperature: float = 1.0,
gap_open_penalty: float = -10.0,
gap_extend_penalty: float = -1.0,
)
Bases: OperatorConfig
Configuration for SoftProgressiveMSA.
Attributes:
| Name | Type | Description |
|---|---|---|
| max_seq_length | `int` | Maximum sequence length. |
| hidden_dim | `int` | Hidden dimension for neural networks. |
| num_layers | `int` | Number of encoder layers. |
| alphabet_size | `int` | Size of sequence alphabet (4 for DNA, 20 for protein). |
| temperature | `float` | Temperature for softmax operations. |
| gap_open_penalty | `float` | Gap opening penalty. |
| gap_extend_penalty | `float` | Gap extension penalty. |
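To make the two gap parameters concrete, the sketch below scores a single gap under one common affine convention: the first gapped position costs the opening penalty and each further position costs the extension penalty. Whether SoftProgressiveMSA uses exactly this convention is an assumption, and `affine_gap_score` is a hypothetical helper that reuses the default values from the signature above.

```python
def affine_gap_score(length, gap_open=-10.0, gap_extend=-1.0):
    """Score contribution of one gap of `length` positions (assumed convention)."""
    if length <= 0:
        return 0.0  # no gap, no penalty
    # First position pays the opening penalty, the rest pay the extension penalty.
    return gap_open + (length - 1) * gap_extend


single = affine_gap_score(1)  # opening penalty only
triple = affine_gap_score(3)  # opening penalty plus two extensions
```

With an opening penalty much larger in magnitude than the extension penalty, one long gap scores better than several short gaps of the same total length.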
ProfileHMMSearch
diffbio.operators.alignment.profile_hmm.ProfileHMMSearch
ProfileHMMSearch(
config: ProfileHMMConfig,
*,
rngs: Rngs | None = None,
name: str | None = None,
)
Bases: TemperatureOperator
Profile HMM search with differentiable scoring.
This operator implements a simplified profile HMM with match, insert, and delete states. The forward algorithm computes the alignment score differentiably using logsumexp.
Profile HMM structure (per position):
- Match state: emits according to a position-specific distribution
- Insert state: emits according to the background distribution
- Delete state: silent (no emission)
Transitions:
- M->M, M->I, M->D (from match)
- I->M, I->I (from insert)
- D->M, D->D (from delete)
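The state structure and transitions above correspond to the textbook forward recursion for profile HMMs, sketched below in log space with logsumexp. This is an illustration rather than the operator's code. It assumes position-independent transition scores, treats the begin state as match column 0, and omits explicit end transitions. The function name and argument layout are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

NEG = -1e9  # finite stand-in for log(0); keeps gradients free of NaNs


def lse(*terms):
    """logsumexp over a handful of scalar log-scores."""
    return logsumexp(jnp.stack(terms))


def profile_hmm_forward(sequence, log_match, log_background, t):
    """Forward log-score of `sequence` under a simplified profile HMM.

    sequence: (T, A) one-hot; log_match: (L, A); log_background: (A,);
    t: dict of scalar log transition scores keyed "MM", "MI", "MD", "IM", "II", "DM", "DD".
    """
    T, L = sequence.shape[0], log_match.shape[0]
    emit_m = sequence @ log_match.T     # (T, L) match emission log-probs
    emit_i = sequence @ log_background  # (T,) insert emission log-probs
    neg = jnp.asarray(NEG)
    fm = [[neg] * (L + 1) for _ in range(T + 1)]  # match scores
    fi = [[neg] * (L + 1) for _ in range(T + 1)]  # insert scores
    fd = [[neg] * (L + 1) for _ in range(T + 1)]  # delete scores
    fm[0][0] = jnp.asarray(0.0)  # begin state
    for i in range(T + 1):
        for j in range(L + 1):
            if i > 0 and j > 0:  # match: consume a symbol and advance a column
                fm[i][j] = emit_m[i - 1, j - 1] + lse(
                    fm[i - 1][j - 1] + t["MM"],
                    fi[i - 1][j - 1] + t["IM"],
                    fd[i - 1][j - 1] + t["DM"],
                )
            if i > 0:  # insert: consume a symbol, stay in the same column
                fi[i][j] = emit_i[i - 1] + lse(
                    fm[i - 1][j] + t["MI"], fi[i - 1][j] + t["II"]
                )
            if j > 0:  # delete: advance a column without consuming a symbol
                fd[i][j] = lse(fm[i][j - 1] + t["MD"], fd[i][j - 1] + t["DD"])
    return lse(fm[T][L], fi[T][L], fd[T][L])


# Illustrative parameters: a length-3 profile that prefers symbols 0, 1, 2.
probs = {"MM": 0.9, "MI": 0.05, "MD": 0.05, "IM": 0.5, "II": 0.5, "DM": 0.5, "DD": 0.5}
log_trans = {k: jnp.log(v) for k, v in probs.items()}
log_match = jax.nn.log_softmax(5.0 * jax.nn.one_hot(jnp.array([0, 1, 2]), 4), axis=-1)
log_background = jnp.log(jnp.full((4,), 0.25))

matching = jax.nn.one_hot(jnp.array([0, 1, 2]), 4)
mismatching = jax.nn.one_hot(jnp.array([3, 3, 3]), 4)
score_match = profile_hmm_forward(matching, log_match, log_background, log_trans)
score_mismatch = profile_hmm_forward(mismatching, log_match, log_background, log_trans)
grad_match = jax.grad(profile_hmm_forward, argnums=1)(
    matching, log_match, log_background, log_trans
)
```

Because every step is a sum or a logsumexp, the score is differentiable with respect to the emission and transition parameters, as `grad_match` shows for the match emissions. The Python loops are unrolled at trace time, so a production version would more likely use `jax.lax.scan`.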
Inherits from TemperatureOperator to get:
- _temperature property for temperature-controlled smoothing
- soft_max() for logsumexp-based smooth maximum
- soft_argmax() for soft position selection
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | `ProfileHMMConfig` | ProfileHMMConfig with model parameters. | required |
| rngs | `Rngs \| None` | Flax NNX random number generators. | `None` |
| name | `str \| None` | Optional operator name. | `None` |
apply
apply(
data: PyTree,
state: PyTree,
metadata: dict[str, Any] | None,
random_params: Any = None,
stats: dict[str, Any] | None = None,
) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply profile HMM search to sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | `PyTree` | Dictionary containing `"sequence"`: one-hot encoded sequence of shape `(seq_len, alphabet_size)`. | required |
| state | `PyTree` | Element state (passed through unchanged). | required |
| metadata | `dict[str, Any] \| None` | Element metadata (passed through unchanged). | required |
| random_params | `Any` | Not used (deterministic operator). | `None` |
| stats | `dict[str, Any] \| None` | Not used. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata): - transformed_data contains: |
ProfileHMMConfig
diffbio.operators.alignment.profile_hmm.ProfileHMMConfig (dataclass)
ProfileHMMConfig(
cacheable: bool = True,
profile_length: int = 100,
alphabet_size: int = 20,
temperature: float = 1.0,
learnable_profile: bool = True,
)
Bases: OperatorConfig
Configuration for ProfileHMMSearch.
Attributes:
| Name | Type | Description |
|---|---|---|
| profile_length | `int` | Length of the profile (number of match states). |
| alphabet_size | `int` | Size of sequence alphabet (20 for protein, 4 for DNA). |
| temperature | `float` | Temperature for softmax operations. |
| learnable_profile | `bool` | Whether profile parameters are learnable. |
Usage Examples
Multiple Sequence Alignment
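import jax
from flax import nnx
from diffbio.operators.alignment import SoftProgressiveMSA, SoftProgressiveMSAConfig

config = SoftProgressiveMSAConfig(
    max_seq_length=100,
    hidden_dim=64,
    alphabet_size=4,
)
msa = SoftProgressiveMSA(config, rngs=nnx.Rngs(42))

tokens = jax.random.randint(jax.random.PRNGKey(0), (4, 100), 0, 4)  # illustrative random DNA tokens
sequences = jax.nn.one_hot(tokens, 4)  # (n_seqs, seq_len, alphabet_size)

data = {"sequences": sequences}
result, _, _ = msa.apply(data, {}, None)  # state and metadata pass through unchanged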