
Alignment Losses API

Loss functions for sequence alignment optimization.

AlignmentScoreLoss

diffbio.losses.alignment_losses.AlignmentScoreLoss

AlignmentScoreLoss(*, rngs: Rngs | None = None)

Bases: Module

Loss function based on alignment quality score.

Computes a loss that measures how well the alignment captures the similarity between two sequences. Lower loss indicates better alignment of similar positions.

The loss computes the weighted sum of position-wise mismatches, where weights come from the soft alignment matrix.
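Written out, one plausible reading of this description (an assumption: the exact mismatch term and any normalization are not documented here) uses $x_i$ for row $i$ of `seq1`, $y_j$ for row $j$ of `seq2`, and $A_{ij}$ for `alignment[i, j]`:

$$
\mathcal{L}_{\text{score}} = \sum_{i=1}^{\text{len1}} \sum_{j=1}^{\text{len2}} A_{ij}\,\bigl(1 - \langle x_i, y_j \rangle\bigr)
$$

For one-hot encodings, $\langle x_i, y_j \rangle$ is 1 when the two symbols match and 0 otherwise, so every aligned mismatch adds its alignment weight $A_{ij}$ to the loss.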

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rngs` | `Rngs \| None` | Flax NNX random number generators (optional). | `None` |

__call__

__call__(
    seq1: Float[Array, "len1 alphabet"],
    seq2: Float[Array, "len2 alphabet"],
    alignment: Float[Array, "len1 len2"],
) -> Float[Array, ""]

Compute alignment score loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `seq1` | `Float[Array, 'len1 alphabet']` | First sequence, soft one-hot encoded, shape (len1, alphabet). | required |
| `seq2` | `Float[Array, 'len2 alphabet']` | Second sequence, soft one-hot encoded, shape (len2, alphabet). | required |
| `alignment` | `Float[Array, 'len1 len2']` | Soft alignment matrix; `alignment[i, j]` is the probability of aligning position `i` of `seq1` to position `j` of `seq2`. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar loss value. Lower is better alignment. |
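A minimal NumPy sketch of this computation, assuming the mismatch between positions `i` and `j` is one minus the dot product of their encodings and that the weighted terms are summed without normalization. The function name `alignment_score_loss` is hypothetical and not part of diffbio, whose actual reduction may differ. NumPy is used so the sketch runs without JAX; the same array calls exist in `jax.numpy`.

```python
import numpy as np


def alignment_score_loss(seq1, seq2, alignment):
    """Hypothetical sketch: alignment-weighted sum of position-wise mismatches.

    seq1: (len1, alphabet) soft one-hot rows
    seq2: (len2, alphabet) soft one-hot rows
    alignment: (len1, len2) soft alignment weights
    """
    # similarity[i, j] = <seq1[i], seq2[j]>; 1 for matching one-hot symbols
    similarity = seq1 @ seq2.T
    # mismatch[i, j] is 0 for a match and 1 for a mismatch (one-hot inputs)
    mismatch = 1.0 - similarity
    # Weight each pairwise mismatch by how strongly positions i and j are aligned
    return np.sum(alignment * mismatch)


# Toy example over the alphabet A, C, G, T
eye = np.eye(4)
acgt = eye[[0, 1, 2, 3]]
acga = eye[[0, 1, 2, 0]]
diagonal = np.eye(4)  # align position i of seq1 to position i of seq2

print(alignment_score_loss(acgt, acgt, diagonal))  # 0.0
print(alignment_score_loss(acgt, acga, diagonal))  # 1.0 (T vs A at the last position)
```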

SoftEditDistanceLoss

diffbio.losses.alignment_losses.SoftEditDistanceLoss

SoftEditDistanceLoss(
    normalize: bool = False,
    temperature: float = 0.1,
    *,
    rngs: Rngs | None = None,
)

Bases: Module

Differentiable approximation of edit distance.

Computes a soft (differentiable) version of the edit distance between two sequences, so that gradients can flow through it. It relies on the relationship between edit distance and alignment score: the edit distance is approximated as the complement of the optimal alignment score, scaled appropriately.
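The description does not spell out how the alignment score is turned into a distance, so the sketch below swaps in a different, commonly used construction: the Levenshtein dynamic program with unit insertion and deletion costs, a substitution cost of one minus the dot product of the two encodings, and every `min` replaced by a log-sum-exp soft minimum controlled by `temperature`. Dividing by the longer length when `normalize=True` is also an assumption, and neither function is part of diffbio. Because the soft minimum never exceeds the hard minimum, this variant can dip slightly below 0 for identical sequences at larger temperatures. A JAX port would need `.at[...].set(...)` updates or `jax.lax.scan` in place of the in-place assignments.

```python
import numpy as np


def softmin(values, temperature):
    """Smooth minimum -t * log(sum(exp(-v / t))), computed stably."""
    v = np.asarray(values, dtype=float) / -temperature
    m = np.max(v)
    return -temperature * (m + np.log(np.sum(np.exp(v - m))))


def soft_edit_distance(seq1, seq2, temperature=0.1, normalize=False):
    """Hypothetical sketch: Levenshtein DP with min replaced by softmin."""
    len1, len2 = seq1.shape[0], seq2.shape[0]
    # Soft substitution cost: 0 for identical one-hot rows, 1 for different ones
    sub_cost = 1.0 - seq1 @ seq2.T
    d = np.zeros((len1 + 1, len2 + 1))
    d[:, 0] = np.arange(len1 + 1)  # delete every symbol of the seq1 prefix
    d[0, :] = np.arange(len2 + 1)  # insert every symbol of the seq2 prefix
    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            d[i, j] = softmin(
                [
                    d[i - 1, j] + 1.0,  # deletion
                    d[i, j - 1] + 1.0,  # insertion
                    d[i - 1, j - 1] + sub_cost[i - 1, j - 1],  # match or substitution
                ],
                temperature,
            )
    distance = d[len1, len2]
    if normalize:
        distance = distance / max(len1, len2)  # assumed normalization
    return distance


eye = np.eye(4)  # one-hot rows for A, C, G, T
acgt = eye[[0, 1, 2, 3]]
agt = eye[[0, 2, 3]]
print(soft_edit_distance(acgt, acgt))  # close to 0
print(soft_edit_distance(acgt, agt))  # just below the hard edit distance of 1
```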

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `normalize` | `bool` | Whether to normalize the distance by sequence length. | `False` |
| `temperature` | `float` | Temperature for the soft-minimum operations. Lower values give a sharper approximation of the true edit distance; the default of 0.1 works well for one-hot encoded sequences. | `0.1` |
| `rngs` | `Rngs \| None` | Flax NNX random number generators (optional). | `None` |
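Assuming the soft minimum is the usual log-sum-exp construction (the documentation does not say), the temperature $\tau$ bounds how far it can fall below the hard minimum of $n$ candidate values:

$$
\operatorname{softmin}_\tau(v_1,\dots,v_n) = -\tau \log \sum_{k=1}^{n} e^{-v_k/\tau},
\qquad
\min_k v_k - \tau \log n \;\le\; \operatorname{softmin}_\tau(v) \;\le\; \min_k v_k .
$$

Shrinking $\tau$ therefore tightens the approximation to the true minimum, at the cost of gradients that concentrate on a single candidate, since $\partial \operatorname{softmin}_\tau / \partial v_k$ is the softmax weight of $-v_k/\tau$.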

__call__

__call__(
    seq1: Float[Array, "len1 alphabet"],
    seq2: Float[Array, "len2 alphabet"],
) -> Float[Array, ""]

Compute soft edit distance between sequences.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `seq1` | `Float[Array, 'len1 alphabet']` | First sequence, soft one-hot encoded, shape (len1, alphabet). | required |
| `seq2` | `Float[Array, 'len2 alphabet']` | Second sequence, soft one-hot encoded, shape (len2, alphabet). | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar soft edit distance; 0 for identical sequences. |

AlignmentConsistencyLoss

diffbio.losses.alignment_losses.AlignmentConsistencyLoss

AlignmentConsistencyLoss(*, rngs: Rngs | None = None)

Bases: Module

Loss for enforcing transitivity in multi-sequence alignments.

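For three sequences A, B and C with pairwise alignments:

- A->B (align_ab)
- B->C (align_bc)
- A->C (align_ac)

the alignments are consistent if align_ac ≈ align_ab @ align_bc, that is, if composing the A->B and B->C alignments by matrix product reproduces the A->C alignment.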

This loss penalizes violations of this transitivity property, which is important for producing coherent multiple sequence alignments.
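Entry by entry, with $P$ = `align_ab`, $Q$ = `align_bc` and $R$ = `align_ac`, the consistency condition sums over every intermediate position $j$ of B. The squared-error penalty on the right is only one plausible choice (an assumption; the norm diffbio uses is not stated here):

$$
R_{ik} \approx (PQ)_{ik} = \sum_{j} P_{ij}\,Q_{jk},
\qquad
\mathcal{L}_{\text{consist}} = \frac{1}{\text{len}_a\,\text{len}_c} \sum_{i,k} \bigl(R_{ik} - (PQ)_{ik}\bigr)^2
$$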

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rngs` | `Rngs \| None` | Flax NNX random number generators (optional). | `None` |

__call__

__call__(
    align_ab: Float[Array, "len_a len_b"],
    align_bc: Float[Array, "len_b len_c"],
    align_ac: Float[Array, "len_a len_c"],
) -> Float[Array, ""]

Compute alignment consistency loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `align_ab` | `Float[Array, 'len_a len_b']` | Soft alignment from sequence A to B. | required |
| `align_bc` | `Float[Array, 'len_b len_c']` | Soft alignment from sequence B to C. | required |
| `align_ac` | `Float[Array, 'len_a len_c']` | Soft alignment from sequence A to C. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar loss measuring transitivity violation. |
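A NumPy sketch of such a loss, assuming the violation is measured as the mean squared difference between `align_ac` and the composed alignment `align_ab @ align_bc`. The function name `alignment_consistency_loss` is hypothetical; it illustrates the transitivity check, not diffbio's implementation.

```python
import numpy as np


def alignment_consistency_loss(align_ab, align_bc, align_ac):
    """Hypothetical sketch: mean squared gap between align_ac and A->B->C."""
    # Compose A->B with B->C by summing over every intermediate position of B
    composed = align_ab @ align_bc  # shape (len_a, len_c)
    # Assumed penalty: mean squared difference over all (i, k) entries
    return np.mean((align_ac - composed) ** 2)


# A and B have 3 positions, C has 2; B's last position is unaligned in C
align_ab = np.eye(3)
align_bc = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
consistent_ac = align_ab @ align_bc
swapped_ac = consistent_ac[:, ::-1]  # contradicts the A->B->C path

print(alignment_consistency_loss(align_ab, align_bc, consistent_ac))  # 0.0
print(alignment_consistency_loss(align_ab, align_bc, swapped_ac))  # 4/6, about 0.667
```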

Usage Example

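import jax
import jax.numpy as jnp

from diffbio.losses import AlignmentScoreLoss, SoftEditDistanceLoss

# Toy inputs: one-hot encoded sequences over a 4-letter alphabet (e.g. A, C, G, T)
sequence1 = jax.nn.one_hot(jnp.array([0, 1, 2, 3]), 4)  # shape (4, 4)
sequence2 = jax.nn.one_hot(jnp.array([0, 1, 2, 0]), 4)  # shape (4, 4)
# Soft alignment matrix of shape (len1, len2); here a hard diagonal alignment
alignment_matrix = jnp.eye(4)

# Alignment score loss
alignment_loss = AlignmentScoreLoss()
loss = alignment_loss(
    seq1=sequence1,
    seq2=sequence2,
    alignment=alignment_matrix,
)

# Soft edit distance
edit_loss = SoftEditDistanceLoss(normalize=True, temperature=0.1)
distance = edit_loss(seq1=sequence1, seq2=sequence2)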