
CRISPR Operators API

Differentiable operators for CRISPR guide RNA design and on-target efficiency prediction.

DifferentiableCRISPRScorer

diffbio.operators.crispr.guide_scoring.DifferentiableCRISPRScorer

DifferentiableCRISPRScorer(
    config: CRISPRScorerConfig, *, rngs: Rngs
)

Bases: OperatorModule

DeepCRISPR-style differentiable guide RNA scoring.

This operator uses a 1D CNN architecture to predict CRISPR guide RNA on-target efficiency from sequence features. The model learns sequence patterns that correlate with efficient target cleavage.

The architecture consists of:

1. 1D convolutional layers for sequence feature extraction
2. Batch normalization and ReLU activations
3. Fully connected layers for efficiency prediction
4. Sigmoid output for efficiency score in [0, 1]
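The four stages listed above can be sketched end to end in NumPy. This is a toy illustration with random, untrained weights, not the diffbio implementation: the kernel width (3), the channel sizes (8, 16), the zero "same" padding, the global max pooling step, the parameter-free batch normalization, and every helper name here are assumptions made for the example.

```python
import numpy as np


def conv1d_same(x, kernel, bias):
    """1D convolution with zero 'same' padding.

    x: (n_guides, length, in_channels)
    kernel: (width, in_channels, out_channels)
    bias: (out_channels,)
    """
    width = kernel.shape[0]
    pad_left = (width - 1) // 2
    pad_right = width - 1 - pad_left
    padded = np.pad(x, ((0, 0), (pad_left, pad_right), (0, 0)))
    length = x.shape[1]
    # Stack sliding windows: (n_guides, length, width, in_channels)
    windows = np.stack([padded[:, i:i + width, :] for i in range(length)], axis=1)
    return np.einsum("nlwc,wco->nlo", windows, kernel) + bias


def batch_norm(x, eps=1e-5):
    """Normalize each channel over the batch and position axes (no learned scale/shift)."""
    mean = x.mean(axis=(0, 1), keepdims=True)
    var = x.var(axis=(0, 1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def toy_scorer(guides, rng):
    """Score one-hot guides of shape (n_guides, guide_length, 4)."""
    h = guides
    in_ch = guides.shape[-1]
    # Stages 1 and 2: convolution, batch normalization, ReLU
    for out_ch in (8, 16):
        kernel = rng.normal(scale=0.1, size=(3, in_ch, out_ch))
        h = np.maximum(batch_norm(conv1d_same(h, kernel, np.zeros(out_ch))), 0.0)
        in_ch = out_ch
    # Assumed pooling step: global max over positions -> (n_guides, 16)
    features = h.max(axis=1)
    # Stage 3: one fully connected layer with ReLU
    w_fc = rng.normal(scale=0.1, size=(in_ch, 8))
    hidden = np.maximum(features @ w_fc, 0.0)
    # Stage 4: sigmoid output, one score per guide
    w_out = rng.normal(scale=0.1, size=(8, 1))
    return sigmoid(hidden @ w_out)[:, 0]


rng = np.random.default_rng(0)
guides = np.eye(4)[rng.integers(0, 4, size=(5, 23))]  # random one-hot guides, (5, 23, 4)
scores = toy_scorer(guides, rng)  # shape (5,)
```

Because the last step is a sigmoid, every score falls between 0 and 1 regardless of the weights, which matches the output range documented for the operator.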

Attributes:

Name Description
config Operator configuration.
conv_layers 1D convolutional layers.
conv_bn Batch normalization layers applied after the convolutional layers.
ffn_backbone Shared Artifex MLP for score prediction.
output_head Final output layer.

Example
from flax import nnx
from diffbio.operators.crispr import (
    DifferentiableCRISPRScorer,
    CRISPRScorerConfig,
)
config = CRISPRScorerConfig(guide_length=23)
scorer = DifferentiableCRISPRScorer(config, rngs=nnx.Rngs(42))
# guide_sequences: one-hot encoded array of shape (n_guides, guide_length, 4)
data = {"guides": guide_sequences}
result, _, _ = scorer.apply(data, {}, None)
scores = result["efficiency_scores"]  # (n_guides,)

Parameters:

Name Type Description Default
config CRISPRScorerConfig

Operator configuration.

required
rngs Rngs

Flax NNX random number generators.

required

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Apply CRISPR scoring to guide sequences.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary containing:
  • "guides": One-hot encoded guides (n_guides, guide_length, 4).

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata (passed through).

required
random_params Any

Random parameters for stochastic operations.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of (transformed_data, state, metadata), where transformed_data contains:
  • "guides": Original guide sequences.
  • "efficiency_scores": Predicted efficiency (n_guides,).
  • "features": Extracted feature vectors.

extract_features

extract_features(guides: ndarray) -> ndarray

Extract features from guide sequences using CNN.

Parameters:

Name Type Description Default
guides ndarray

One-hot encoded guides (n_guides, guide_length, 4).

required

Returns:

Type Description
ndarray

Feature vectors (n_guides, feature_dim).

predict_efficiency

predict_efficiency(features: ndarray) -> ndarray

Predict efficiency score from features.

Parameters:

Name Type Description Default
features ndarray

Feature vectors (n_guides, feature_dim).

required

Returns:

Type Description
ndarray

Efficiency scores (n_guides,) in range [0, 1].

CRISPRScorerConfig

diffbio.operators.crispr.guide_scoring.CRISPRScorerConfig dataclass

CRISPRScorerConfig(
    guide_length: int = 23,
    alphabet_size: int = 4,
    hidden_channels: tuple[int, ...] = (64, 128, 256),
    fc_dims: tuple[int, ...] = (256, 128),
    dropout_rate: float = 0.2,
)

Bases: OperatorConfig

Configuration for DifferentiableCRISPRScorer.

Attributes:

Name Type Description
guide_length int

Length of guide RNA sequence (typically 20-23 nt).

alphabet_size int

Size of nucleotide alphabet (4 for A/C/G/T).

hidden_channels tuple[int, ...]

CNN hidden channel dimensions.

fc_dims tuple[int, ...]

Fully connected layer dimensions.

dropout_rate float

Dropout rate for regularization.
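Since the configuration fixes guide_length and alphabet_size, inputs presumably need to match those two dimensions. A small shape check before calling the scorer can surface mismatches early. The helper below is hypothetical (check_guides is not part of diffbio) and only illustrates the idea with NumPy.

```python
import numpy as np


def check_guides(guides, guide_length=23, alphabet_size=4):
    """Raise ValueError unless guides has shape (n_guides, guide_length, alphabet_size)."""
    guides = np.asarray(guides)
    expected = (guide_length, alphabet_size)
    if guides.ndim != 3 or guides.shape[1:] != expected:
        raise ValueError(
            f"expected shape (n_guides, {guide_length}, {alphabet_size}), "
            f"got {guides.shape}"
        )
    return guides


# A batch that matches the default configuration passes through unchanged.
ok = check_guides(np.zeros((10, 23, 4)))

# A batch of 20-nt guides is rejected when guide_length is 23.
try:
    check_guides(np.zeros((10, 20, 4)))
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```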

create_crispr_scorer

diffbio.operators.crispr.guide_scoring.create_crispr_scorer

create_crispr_scorer(
    guide_length: int = 23,
    hidden_channels: tuple[int, ...] = (64, 128, 256),
    fc_dims: tuple[int, ...] = (256, 128),
    dropout_rate: float = 0.2,
    seed: int = 42,
) -> DifferentiableCRISPRScorer

Factory function to create a CRISPR scorer.

Parameters:

Name Type Description Default
guide_length int

Length of guide RNA sequence.

23
hidden_channels tuple[int, ...]

CNN hidden channel dimensions.

(64, 128, 256)
fc_dims tuple[int, ...]

Fully connected layer dimensions.

(256, 128)
dropout_rate float

Dropout rate for regularization.

0.2
seed int

Random seed for initialization.

42

Returns:

Type Description
DifferentiableCRISPRScorer

Configured DifferentiableCRISPRScorer instance.

Example
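from diffbio.operators.crispr import create_crispr_scorer
scorer = create_crispr_scorer(guide_length=23)
# data: one-hot encoded guides of shape (n_guides, 23, 4)
result, _, _ = scorer.apply({"guides": data}, {}, None)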

Usage Examples

Basic Guide Scoring

from flax import nnx
import jax
from diffbio.operators.crispr import (
    DifferentiableCRISPRScorer,
    CRISPRScorerConfig,
    create_crispr_scorer,
)

# Using config
config = CRISPRScorerConfig(
    guide_length=23,
    hidden_channels=(64, 128, 256),
    fc_dims=(256, 128),
)
scorer = DifferentiableCRISPRScorer(config, rngs=nnx.Rngs(42))

# Or using factory function
scorer = create_crispr_scorer(guide_length=23)

# Score guides
guide_indices = jax.random.randint(jax.random.PRNGKey(0), (100, 23), 0, 4)
guides = jax.nn.one_hot(guide_indices, 4)

result, _, _ = scorer.apply({"guides": guides}, {}, None)
scores = result["efficiency_scores"]  # (100,) in [0, 1]

Training Mode

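# Enable dropout during training
scorer.train()

# train_dataloader and train_step are placeholders for your own
# batch iterator and parameter-update function
for batch in train_dataloader:
    loss = train_step(scorer, batch)

# Disable dropout for inference
scorer.eval()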
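To make the difference between the two modes concrete, here is a minimal NumPy sketch of inverted dropout, the standard formulation of the technique. It is not taken from diffbio: the dropout function and its arguments are illustrative, and only the rate of 0.2 comes from the default configuration above.

```python
import numpy as np


def dropout(x, rate, rng, training):
    """Inverted dropout: zero a fraction of units and rescale the rest while training."""
    if not training or rate == 0.0:
        # Inference mode: the input passes through unchanged.
        return x
    keep = rng.random(x.shape) >= rate
    # Scale kept units by 1 / (1 - rate) so the expected activation is unchanged.
    return np.where(keep, x / (1.0 - rate), 0.0)


rng = np.random.default_rng(0)
activations = np.ones((4, 128))

train_out = dropout(activations, rate=0.2, rng=rng, training=True)
eval_out = dropout(activations, rate=0.2, rng=rng, training=False)
```

In training mode roughly 20% of the entries become zero and the rest are scaled to 1.25. In inference mode the output equals the input, which is why predictions are deterministic after calling scorer.eval().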

Accessing Components

# Convolutional layers
conv_layers = scorer.conv_layers

# Fully connected backbone (listed as ffn_backbone in the class attributes)
ffn_backbone = scorer.ffn_backbone

# Output head
output_head = scorer.output_head

Input Specifications

Key Shape Description
guides (n_guides, guide_length, 4) One-hot encoded guide sequences
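Guides usually start out as nucleotide strings, so they need to be converted to the one-hot layout above before scoring. The sketch below shows one way to do that with NumPy. The one_hot_guides helper is hypothetical, the A/C/G/T channel order is an assumption (the library may expect a different order), and the two sequences are arbitrary 23-nt strings made up for the example.

```python
import numpy as np

NUCLEOTIDES = "ACGT"  # assumed channel order; verify against your trained model


def one_hot_guides(sequences):
    """Encode equal-length DNA strings as a (n_guides, guide_length, 4) float array."""
    lengths = {len(s) for s in sequences}
    if len(lengths) != 1:
        raise ValueError("all guides must have the same length")
    index = {base: i for i, base in enumerate(NUCLEOTIDES)}
    encoded = np.zeros(
        (len(sequences), lengths.pop(), len(NUCLEOTIDES)), dtype=np.float32
    )
    for n, seq in enumerate(sequences):
        for pos, base in enumerate(seq.upper()):
            if base not in index:
                raise ValueError(f"unsupported base {base!r} in guide {n}")
            encoded[n, pos, index[base]] = 1.0
    return encoded


# Arbitrary example sequences, 23 nt each
guides = one_hot_guides(["GACGTTACCGGATTCAGGCTAGG", "TTGACCATGCAAGTCCGATTCGG"])
```

The resulting array has shape (2, 23, 4) and can be passed as the "guides" entry of the input dictionary, for example after converting it with jax.numpy.asarray.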

Output Specifications

Key Shape Description
guides (n_guides, guide_length, 4) Original guide sequences
efficiency_scores (n_guides,) Predicted efficiency scores in [0, 1]
features (n_guides, feature_dim) Extracted CNN features
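A common next step is to rank guides by predicted efficiency. The sketch below assumes only the "efficiency_scores" key documented above. The top_guides helper is hypothetical, and the result dictionary is a stand-in with made-up values rather than real model output.

```python
import numpy as np


def top_guides(result, k=3):
    """Return (indices, scores) of the k highest-scoring guides, best first."""
    scores = np.asarray(result["efficiency_scores"])
    order = np.argsort(-scores, kind="stable")[:k]
    return order, scores[order]


# Stand-in for the dictionary returned by scorer.apply; values are made up.
result = {"efficiency_scores": np.array([0.12, 0.87, 0.45, 0.91, 0.30])}
indices, best = top_guides(result, k=2)
```

The returned indices can be used to select the matching rows of the "guides" and "features" arrays, since all three share the same leading n_guides axis.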