CRISPR Operators API
Differentiable operators for CRISPR guide RNA design and on-target efficiency prediction.
DifferentiableCRISPRScorer
diffbio.operators.crispr.guide_scoring.DifferentiableCRISPRScorer
```python
DifferentiableCRISPRScorer(
    config: CRISPRScorerConfig, *, rngs: Rngs
)
```
Bases: OperatorModule
DeepCRISPR-style differentiable guide RNA scoring.
This operator uses a 1D CNN architecture to predict CRISPR guide RNA on-target efficiency from sequence features. The model learns sequence patterns that correlate with efficient target cleavage.
The architecture consists of:
1. 1D convolutional layers for sequence feature extraction
2. Batch normalization and ReLU activations
3. Fully connected layers for efficiency prediction
4. A sigmoid output that maps the efficiency score into [0, 1]
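To make step 1 concrete, here is a minimal NumPy sketch (not diffbio's implementation) of how a single 1D convolutional filter sliding over a one-hot guide acts as a motif detector. The ACGT channel order, the `one_hot` and `conv1d_valid` helpers, and the hand-set "GG" filter are illustrative assumptions. In the real operator the filters are learned rather than set by hand.

```python
import numpy as np

ALPHABET = "ACGT"  # assumed channel order of the one-hot encoding


def one_hot(seq: str) -> np.ndarray:
    """Encode a DNA string as a (length, 4) one-hot array."""
    idx = np.array([ALPHABET.index(base) for base in seq])
    return np.eye(len(ALPHABET))[idx]


def conv1d_valid(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Unpadded 1D convolution (cross-correlation, as in deep learning libraries).

    x: (length, in_channels); kernel: (width, in_channels, out_channels).
    Returns an array of shape (length - width + 1, out_channels).
    """
    width = kernel.shape[0]
    windows = np.stack([x[i:i + width] for i in range(x.shape[0] - width + 1)])
    return np.einsum("pwc,wco->po", windows, kernel)


# One hand-set filter that fires on the dinucleotide "GG" (learned in a real model).
kernel = one_hot("GG")[:, :, None]                         # (width=2, in=4, out=1)
guide = one_hot("ACGTGGACGT")                              # (10, 4)
response = np.maximum(conv1d_valid(guide, kernel), 0.0)    # conv + ReLU
print(int(np.argmax(response[:, 0])))                      # 4: "GG" starts at position 4
```

The filter responds with 2 where both positions match, 1 where one matches, and 0 elsewhere. Stacking many such filters across several layers lets the network detect longer and more abstract sequence patterns.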
Attributes:
| Name | Description |
|---|---|
| `config` | Operator configuration. |
| `conv_layers` | 1D convolutional layers. |
| `conv_bn` | Batch normalization layers for the convolutional layers. |
| `ffn_backbone` | Shared Artifex MLP for score prediction. |
| `output_head` | Final output layer. |
Example
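```python
import jax
from flax import nnx
from diffbio.operators.crispr import (
    DifferentiableCRISPRScorer,
    CRISPRScorerConfig,
)

config = CRISPRScorerConfig(guide_length=23)
scorer = DifferentiableCRISPRScorer(config, rngs=nnx.Rngs(42))

# Random one-hot guides for illustration: (n_guides, guide_length, 4)
guide_indices = jax.random.randint(jax.random.PRNGKey(0), (8, 23), 0, 4)
guide_sequences = jax.nn.one_hot(guide_indices, 4)

data = {"guides": guide_sequences}
result, _, _ = scorer.apply(data, {}, None)
scores = result["efficiency_scores"]  # (n_guides,)
```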
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `CRISPRScorerConfig` | Operator configuration. | required |
| `rngs` | `Rngs` | Flax NNX random number generators. | required |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]
```
Apply CRISPR scoring to guide sequences.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Dictionary containing `"guides"`: one-hot encoded guides of shape (n_guides, guide_length, 4). | required |
| `state` | `dict[str, Any]` | Per-element state (passed through). | required |
| `metadata` | `dict[str, Any] \| None` | Optional metadata (passed through). | required |
| `random_params` | `Any` | Random parameters for stochastic operations. | `None` |
| `stats` | `dict[str, Any] \| None` | Optional statistics dictionary. | `None` |
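Returns:
| Type | Description |
|---|---|
| `tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None]` | Tuple of (transformed_data, state, metadata), where transformed_data contains the original `"guides"`, the predicted `"efficiency_scores"` (n_guides,), and the extracted CNN `"features"` (n_guides, feature_dim). |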
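The return contract can be illustrated with a stand-in. The sketch below uses a hypothetical `toy_apply` function (not diffbio code) whose "features" and "scores" are placeholders rather than a CNN. It shows only the shape of the contract described above: new keys are added next to the original input, while `state` and `metadata` come back unchanged.

```python
from typing import Any, Dict, Optional, Tuple

import numpy as np


def toy_apply(
    data: Dict[str, Any],
    state: Dict[str, Any],
    metadata: Optional[Dict[str, Any]],
) -> Tuple[Dict[str, Any], Dict[str, Any], Optional[Dict[str, Any]]]:
    """Hypothetical stand-in that mimics the apply() return contract."""
    guides = np.asarray(data["guides"])                 # (n_guides, guide_length, 4)
    features = guides.reshape(guides.shape[0], -1)      # placeholder for CNN features
    logits = features @ np.linspace(-1.0, 1.0, features.shape[1])
    efficiency_scores = 1.0 / (1.0 + np.exp(-logits))   # placeholder scores in (0, 1)
    # New keys are added alongside the original input; state and metadata pass through.
    transformed = {**data, "features": features, "efficiency_scores": efficiency_scores}
    return transformed, state, metadata


guides = np.eye(4)[np.array([[0, 1, 2, 3], [3, 3, 2, 2]])]   # (2, 4, 4)
out, state, meta = toy_apply({"guides": guides}, {"step": 0}, None)
print(sorted(out))  # ['efficiency_scores', 'features', 'guides']
```

Unpacking the result as `result, _, _ = scorer.apply(...)` simply discards the passed-through state and metadata.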
extract_features
Extract features from guide sequences using CNN.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `guides` | `ndarray` | One-hot encoded guides (n_guides, guide_length, 4). | required |

Returns:
| Type | Description |
|---|---|
| `ndarray` | Feature vectors (n_guides, feature_dim). |
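A minimal NumPy sketch of a feature extractor with this input/output shape: stacked "same"-padded convolutions with ReLU, followed by global average pooling over positions. The kernel width of 3, the random (untrained) weights, the omission of batch normalization, and the pooling step are all assumptions for illustration. The pooling step is what makes `feature_dim` equal to the last entry of `hidden_channels` here. None of these choices are documented for diffbio.

```python
import numpy as np


def conv1d_same(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """'Same'-padded 1D convolution over a batch.

    x: (n, length, c_in); kernel: (width, c_in, c_out) with odd width.
    Returns (n, length, c_out).
    """
    width = kernel.shape[0]
    pad = width // 2
    xp = np.pad(x, ((0, 0), (pad, pad), (0, 0)))
    windows = np.stack([xp[:, i:i + width] for i in range(x.shape[1])], axis=1)
    return np.einsum("nlwc,wco->nlo", windows, kernel)


def extract_features(guides, hidden_channels=(64, 128, 256), width=3, seed=0):
    """Hypothetical CNN feature extractor: (n, length, 4) -> (n, hidden_channels[-1])."""
    rng = np.random.default_rng(seed)
    h = guides
    for c_out in hidden_channels:
        c_in = h.shape[-1]
        kernel = rng.normal(0.0, 1.0 / np.sqrt(width * c_in), size=(width, c_in, c_out))
        h = np.maximum(conv1d_same(h, kernel), 0.0)   # conv + ReLU (batch norm omitted)
    return h.mean(axis=1)                              # global average pool over positions


guides = np.eye(4)[np.random.default_rng(42).integers(0, 4, size=(5, 23))]  # (5, 23, 4)
features = extract_features(guides)
print(features.shape)  # (5, 256)
```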
predict_efficiency
Predict efficiency score from features.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `features` | `ndarray` | Feature vectors (n_guides, feature_dim). | required |

Returns:
| Type | Description |
|---|---|
| `ndarray` | Efficiency scores (n_guides,) in range [0, 1]. |
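Here is a matching NumPy sketch of the prediction head. The `fc_dims` default of (256, 128) comes from `CRISPRScorerConfig`. The following are assumptions consistent with the architecture list, not documented internals: the untrained random weights, the omitted biases and dropout, and a single-unit output layer followed by a sigmoid. `predict_efficiency` here is a hypothetical stand-in.

```python
import numpy as np


def predict_efficiency(features: np.ndarray, fc_dims=(256, 128), seed: int = 0) -> np.ndarray:
    """Hypothetical MLP head: dense + ReLU per entry of fc_dims, then a 1-unit sigmoid output."""
    rng = np.random.default_rng(seed)
    h = features
    for dim in fc_dims:
        w = rng.normal(0.0, 1.0 / np.sqrt(h.shape[-1]), size=(h.shape[-1], dim))
        h = np.maximum(h @ w, 0.0)                  # biases and dropout omitted for brevity
    w_out = rng.normal(0.0, 1.0 / np.sqrt(h.shape[-1]), size=(h.shape[-1], 1))
    logits = (h @ w_out)[:, 0]                      # (n_guides,)
    return 1.0 / (1.0 + np.exp(-logits))            # sigmoid maps logits into [0, 1]


features = np.random.default_rng(1).normal(size=(5, 256))
scores = predict_efficiency(features)
print(scores.shape)  # (5,)
```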
CRISPRScorerConfig
diffbio.operators.crispr.guide_scoring.CRISPRScorerConfig (dataclass)
```python
CRISPRScorerConfig(
    guide_length: int = 23,
    alphabet_size: int = 4,
    hidden_channels: tuple[int, ...] = (64, 128, 256),
    fc_dims: tuple[int, ...] = (256, 128),
    dropout_rate: float = 0.2,
)
```
Bases: OperatorConfig
Configuration for DifferentiableCRISPRScorer.
Attributes:
| Name | Type | Description |
|---|---|---|
| `guide_length` | `int` | Length of guide RNA sequence (typically 20-23 nt). |
| `alphabet_size` | `int` | Size of nucleotide alphabet (4 for A/C/G/T). |
| `hidden_channels` | `tuple[int, ...]` | CNN hidden channel dimensions. |
| `fc_dims` | `tuple[int, ...]` | Fully connected layer dimensions. |
| `dropout_rate` | `float` | Dropout rate for regularization. |
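`guide_length` and `alphabet_size` together fix the per-guide input shape the scorer expects. Below is a hypothetical pre-flight check, `validate_guides` (not part of diffbio), that catches shape and encoding mistakes before calling `apply`. It assumes inputs should be one-hot, meaning each position sums to 1 across the alphabet axis.

```python
import numpy as np


def validate_guides(guides, guide_length: int = 23, alphabet_size: int = 4) -> np.ndarray:
    """Hypothetical pre-flight check against the shape implied by a CRISPRScorerConfig."""
    guides = np.asarray(guides)
    expected = (guide_length, alphabet_size)
    if guides.ndim != 3 or guides.shape[1:] != expected:
        raise ValueError(
            f"expected shape (n_guides, {guide_length}, {alphabet_size}), got {guides.shape}"
        )
    # Each position should be one-hot, i.e. sum to 1 across the alphabet axis.
    if not np.allclose(guides.sum(axis=-1), 1.0):
        raise ValueError("each position must sum to 1 across the alphabet axis")
    return guides


ok = validate_guides(np.eye(4)[np.zeros((3, 23), dtype=int)])   # three poly-A guides
try:
    validate_guides(np.zeros((3, 20, 4)), guide_length=23)
except ValueError as err:
    print(err)  # expected shape (n_guides, 23, 4), got (3, 20, 4)
```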
create_crispr_scorer
diffbio.operators.crispr.guide_scoring.create_crispr_scorer
```python
create_crispr_scorer(
    guide_length: int = 23,
    hidden_channels: tuple[int, ...] = (64, 128, 256),
    fc_dims: tuple[int, ...] = (256, 128),
    dropout_rate: float = 0.2,
    seed: int = 42,
) -> DifferentiableCRISPRScorer
```
Factory function to create a CRISPR scorer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `guide_length` | `int` | Length of guide RNA sequence. | `23` |
| `hidden_channels` | `tuple[int, ...]` | CNN hidden channel dimensions. | `(64, 128, 256)` |
| `fc_dims` | `tuple[int, ...]` | Fully connected layer dimensions. | `(256, 128)` |
| `dropout_rate` | `float` | Dropout rate for regularization. | `0.2` |
| `seed` | `int` | Random seed for initialization. | `42` |

Returns:
| Type | Description |
|---|---|
| `DifferentiableCRISPRScorer` | Configured DifferentiableCRISPRScorer instance. |
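The factory's `seed` argument suggests that it wraps config construction and RNG creation in a single call. This page does not show its body, so we only assume it behaves like the explicit form with `nnx.Rngs(seed)`. The toy classes below (`ToyScorerConfig`, `ToyScorer`, and `create_toy_scorer` are all hypothetical) illustrate that pattern and the property it provides in the toy: the same seed gives identical initial weights.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyScorerConfig:
    """Hypothetical stand-in for CRISPRScorerConfig (only two fields kept)."""
    guide_length: int = 23
    alphabet_size: int = 4


class ToyScorer:
    """Hypothetical stand-in scorer: a single linear layer plus sigmoid."""

    def __init__(self, config: ToyScorerConfig, seed: int):
        self.config = config
        rng = np.random.default_rng(seed)     # all initialization randomness comes from `seed`
        self.weights = rng.normal(size=(config.guide_length, config.alphabet_size))

    def __call__(self, guides: np.ndarray) -> np.ndarray:
        logits = np.einsum("nla,la->n", guides, self.weights)
        return 1.0 / (1.0 + np.exp(-logits))


def create_toy_scorer(guide_length: int = 23, seed: int = 42) -> ToyScorer:
    """Factory pattern: build the config from keyword arguments, then seed the model."""
    config = ToyScorerConfig(guide_length=guide_length)
    return ToyScorer(config, seed=seed)


guides = np.eye(4)[np.random.default_rng(0).integers(0, 4, size=(3, 23))]
a = create_toy_scorer(seed=42)
b = create_toy_scorer(seed=42)
c = create_toy_scorer(seed=7)
print(np.allclose(a(guides), b(guides)))  # True: same seed, same weights
```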
Usage Examples
Basic Guide Scoring
```python
from flax import nnx
import jax
from diffbio.operators.crispr import (
    DifferentiableCRISPRScorer,
    CRISPRScorerConfig,
    create_crispr_scorer,
)

# Using config
config = CRISPRScorerConfig(
    guide_length=23,
    hidden_channels=(64, 128, 256),
    fc_dims=(256, 128),
)
scorer = DifferentiableCRISPRScorer(config, rngs=nnx.Rngs(42))

# Or using factory function
scorer = create_crispr_scorer(guide_length=23)

# Score 100 random one-hot guides
guide_indices = jax.random.randint(jax.random.PRNGKey(0), (100, 23), 0, 4)
guides = jax.nn.one_hot(guide_indices, 4)
result, _, _ = scorer.apply({"guides": guides}, {}, None)
scores = result["efficiency_scores"]  # (100,) in [0, 1]
```
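Once scores are computed, a common next step is to keep the highest-ranked guides. Here is a small NumPy helper for that (`top_k_guides` is hypothetical, not part of diffbio). It works on the returned score array after conversion with `np.asarray`.

```python
import numpy as np


def top_k_guides(scores, k: int) -> np.ndarray:
    """Return indices of the k highest-scoring guides, best first."""
    scores = np.asarray(scores)
    order = np.argsort(-scores, kind="stable")   # descending; ties keep input order
    return order[:k]


scores = np.array([0.12, 0.87, 0.55, 0.91, 0.30])
print(top_k_guides(scores, 3))  # [3 1 2]
```

In JAX code that must stay on device, `jax.lax.top_k(scores, k)` returns the top values and their indices and serves the same purpose.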
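Training Mode
```python
# Switch to training behavior (e.g. dropout active) before fitting.
# `train_dataloader` and `train_step` are user-defined placeholders, not part of diffbio.
scorer.train()
for batch in train_dataloader:
    loss = train_step(scorer, batch)

# Switch back to inference behavior (dropout disabled) before scoring.
scorer.eval()
```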
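Switching between `scorer.train()` and `scorer.eval()` matters because dropout, and typically batch normalization, behaves differently in the two modes. Below is a NumPy sketch of standard inverted dropout using the default `dropout_rate=0.2`. The `dropout` helper is hypothetical and is not diffbio or Flax code. In training mode it zeroes roughly 20% of activations and rescales the survivors by 1/(1 - rate), so the expected activation is unchanged. In evaluation mode it is the identity.

```python
import numpy as np


def dropout(x: np.ndarray, rate: float, training: bool, rng: np.random.Generator) -> np.ndarray:
    """Inverted dropout: zero a `rate` fraction of units and rescale the rest during training."""
    if not training or rate == 0.0:
        return x                                  # eval mode: identity
    keep = rng.random(x.shape) >= rate            # keep each unit with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)


rng = np.random.default_rng(0)
x = np.ones((1000, 128))
train_out = dropout(x, rate=0.2, training=True, rng=rng)
eval_out = dropout(x, rate=0.2, training=False, rng=rng)
print(train_out.mean(), (train_out == 0).mean())  # close to 1.0 and 0.2
```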
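Accessing Components
```python
# Convolutional layers and their batch normalization layers
conv_layers = scorer.conv_layers
conv_bn = scorer.conv_bn
# Fully connected layers (shared Artifex MLP backbone)
ffn_backbone = scorer.ffn_backbone
# Output head
output_head = scorer.output_head
```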
Input Specifications
| Key | Shape | Description |
|---|---|---|
| `guides` | (n_guides, guide_length, 4) | One-hot encoded guide sequences |
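Guides usually start out as strings, so they must be encoded into this shape first. `encode_guides` below is a hypothetical helper, not a diffbio function. It makes several assumptions: the ACGT channel order (check that it matches the order used for training), the mapping of RNA "U" to "T", and rejecting ambiguous bases such as "N".

```python
import numpy as np

ALPHABET = "ACGT"  # assumed channel order; check it matches the order used in training
_INDEX = {base: i for i, base in enumerate(ALPHABET)}


def encode_guides(seqs, guide_length: int = 23) -> np.ndarray:
    """One-hot encode equal-length guide strings into (n_guides, guide_length, 4)."""
    out = np.zeros((len(seqs), guide_length, len(ALPHABET)), dtype=np.float32)
    for n, seq in enumerate(seqs):
        seq = seq.upper().replace("U", "T")       # accept the RNA spelling of a guide
        if len(seq) != guide_length:
            raise ValueError(f"guide {n} has length {len(seq)}, expected {guide_length}")
        for pos, base in enumerate(seq):
            out[n, pos, _INDEX[base]] = 1.0       # KeyError for ambiguous bases such as N
    return out


guides = encode_guides(["GACGCATAAAGATGAGACGCTGG", "u" * 20 + "agg"])
print(guides.shape)  # (2, 23, 4)
```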
Output Specifications
| Key | Shape | Description |
|---|---|---|
| `guides` | (n_guides, guide_length, 4) | Original guide sequences |
| `efficiency_scores` | (n_guides,) | Predicted efficiency scores in [0, 1] |
| `features` | (n_guides, feature_dim) | Extracted CNN features |
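Because the score is a differentiable function of the input, it can drive gradient-based guide design. Each position is relaxed to a softmax over the alphabet, gradients of the predicted efficiency are pushed back into the logits, and the result is then discretized. The NumPy sketch below illustrates the idea under strong assumptions. A toy linear scorer `W` replaces the CNN, the gradient is derived by hand, and the step size (5.0) and step count (200) are arbitrary. This page does not state whether `DifferentiableCRISPRScorer` accepts relaxed, non-one-hot inputs. With the real operator one would presumably differentiate through `scorer.apply` with `jax.grad` instead of hand-coding the gradient.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)        # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


rng = np.random.default_rng(0)
guide_length, alphabet_size = 23, 4
# Toy linear scorer standing in for the trained CNN (hypothetical weights).
W = rng.normal(scale=0.3, size=(guide_length, alphabet_size))


def score_and_grad(logits: np.ndarray):
    """Score of the relaxed guide softmax(logits) and its gradient w.r.t. logits."""
    p = softmax(logits)                          # "soft one-hot" guide, rows sum to 1
    s = sigmoid(np.sum(W * p))
    # ds/dlogits[i, b] = s (1 - s) * p[i, b] * (W[i, b] - sum_a p[i, a] W[i, a])
    expected_w = np.sum(p * W, axis=-1, keepdims=True)
    return s, s * (1.0 - s) * p * (W - expected_w)


logits = np.zeros((guide_length, alphabet_size))  # start from a uniform distribution
initial, _ = score_and_grad(logits)
for _ in range(200):
    _, grad = score_and_grad(logits)
    logits += 5.0 * grad                           # gradient ascent on the score
final, _ = score_and_grad(logits)
designed = np.argmax(logits, axis=-1)              # discretize: one base per position
print(f"score {initial:.3f} -> {final:.3f}")
```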