CRISPR Guide Design Operators¤
DiffBio provides differentiable operators for CRISPR guide RNA design and scoring, enabling gradient-based optimization of guide selection.
Overview¤
CRISPR guide design operators enable end-to-end, gradient-based optimization of guide selection. This page covers:
- DifferentiableCRISPRScorer: DeepCRISPR-inspired CNN for on-target efficiency prediction
DifferentiableCRISPRScorer¤
CNN-based guide RNA scoring that predicts on-target cleavage efficiency from sequence features.
Quick Start¤
```python
from flax import nnx
import jax
import jax.numpy as jnp
from diffbio.operators.crispr import (
    DifferentiableCRISPRScorer,
    CRISPRScorerConfig,
    create_crispr_scorer,
)

# Configure scorer
config = CRISPRScorerConfig(
    guide_length=23,                  # 20nt guide + 3nt PAM
    hidden_channels=(64, 128, 256),   # CNN channels
    fc_dims=(256, 128),               # Fully connected layers
    dropout_rate=0.2,                 # Regularization
)

# Create operator
rngs = nnx.Rngs(42)
scorer = DifferentiableCRISPRScorer(config, rngs=rngs)

# Prepare guide sequences (one-hot encoded); random placeholders here
# A=0, C=1, G=2, T=3
guide_indices = jax.random.randint(jax.random.PRNGKey(0), (100, 23), 0, 4)
guides = jax.nn.one_hot(guide_indices, 4)  # (n_guides, length, 4)

# Apply scoring
data = {"guides": guides}
result, state, metadata = scorer.apply(data, {}, None)

# Get efficiency scores
scores = result["efficiency_scores"]  # (n_guides,) in [0, 1]
```
Configuration¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| guide_length | int | 23 | Length of guide RNA + PAM context |
| alphabet_size | int | 4 | Nucleotide alphabet size (A/C/G/T) |
| hidden_channels | tuple[int, ...] | (64, 128, 256) | CNN hidden channel dimensions |
| fc_dims | tuple[int, ...] | (256, 128) | Fully connected layer dimensions |
| dropout_rate | float | 0.2 | Dropout rate for regularization |
Architecture¤
The scorer uses a 1D CNN architecture inspired by DeepCRISPR:
```mermaid
graph LR
    A["Guides<br/>(n, 23, 4)"] --> B["Conv1D + BN + ReLU"]
    B --> C["Conv1D + BN + ReLU"]
    C --> D["Conv1D + BN + ReLU"]
    D --> E["Flatten"]
    E --> F["FC + ReLU + Dropout"]
    F --> G["FC + Sigmoid"]
    G --> H["Efficiency Score<br/>(n,) [0,1]"]
    style A fill:#d1fae5,stroke:#059669,color:#064e3b
    style B fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style C fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style D fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style E fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style F fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style G fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style H fill:#d1fae5,stroke:#059669,color:#064e3b
```
The architecture consists of:
- 1D Convolutional layers: Extract sequence motif patterns
- Batch normalization: Stabilize training
- Fully connected layers: Map features to efficiency score
- Sigmoid output: Bound score to [0, 1]
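To make the data flow concrete, the sketch below traces array shapes through the layers listed above. It assumes stride-1 convolutions with "same" padding (so the sequence length stays at 23) and a single output unit after the last hidden FC layer; neither detail is stated on this page, and `trace_shapes` is a hypothetical helper, not part of DiffBio.

```python
def trace_shapes(n_guides, guide_length=23, alphabet_size=4,
                 hidden_channels=(64, 128, 256), fc_dims=(256, 128)):
    """Return (layer_name, shape) pairs for the assumed layer stack."""
    shapes = [("input", (n_guides, guide_length, alphabet_size))]
    # Assumption: each Conv1D keeps the sequence length and only changes channels
    for i, channels in enumerate(hidden_channels, start=1):
        shapes.append((f"conv{i}", (n_guides, guide_length, channels)))
    # Flatten merges the length and channel axes
    shapes.append(("flatten", (n_guides, guide_length * hidden_channels[-1])))
    for i, dim in enumerate(fc_dims, start=1):
        shapes.append((f"fc{i}", (n_guides, dim)))
    # Assumption: one sigmoid unit per guide, squeezed to a vector
    shapes.append(("output", (n_guides,)))
    return shapes


for name, shape in trace_shapes(100):
    print(f"{name:8s} {shape}")
```

Under these assumptions the flattened feature vector has 23 × 256 = 5888 entries per guide, which is what the first fully connected layer would consume.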
Guide Encoding¤
Guides are one-hot encoded with channels for each nucleotide:
```python
# One-hot encoding: A=[1,0,0,0], C=[0,1,0,0], G=[0,0,1,0], T=[0,0,0,1]
import jax
import jax.numpy as jnp

# From sequence string
def encode_guide(sequence):
    mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    indices = jnp.array([mapping[nt] for nt in sequence])
    return jax.nn.one_hot(indices, 4)

# Example: 20nt guide + NGG PAM
guide_seq = "ATCGATCGATCGATCGATCG" + "AGG"  # 23nt total
encoded = encode_guide(guide_seq)  # (23, 4)
```
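The `encode_guide` helper above raises a bare `KeyError` on lowercase input or ambiguity codes such as `N`, and it has no inverse. The following dependency-free sketch shows one way to validate input and to decode an encoding back to a string. `one_hot_rows` and `decode_guide` are hypothetical names introduced for illustration; they operate on nested lists, so a JAX array would need `.tolist()` first.

```python
NUCLEOTIDES = "ACGT"  # channel order used by the encoding above


def one_hot_rows(sequence):
    """Pure-Python one-hot encoding that rejects characters outside A/C/G/T."""
    rows = []
    for position, nt in enumerate(sequence.upper()):
        if nt not in NUCLEOTIDES:
            raise ValueError(f"Unsupported nucleotide {nt!r} at position {position}")
        rows.append([1.0 if nt == ref else 0.0 for ref in NUCLEOTIDES])
    return rows


def decode_guide(encoded):
    """Recover the sequence by taking the arg-max channel at each position."""
    return "".join(
        NUCLEOTIDES[max(range(len(NUCLEOTIDES)), key=lambda j: row[j])]
        for row in encoded
    )


guide_seq = "ATCGATCGATCGATCGATCG" + "AGG"
round_trip = decode_guide(one_hot_rows(guide_seq))
print(round_trip == guide_seq)
```

Decoding by arg-max also works on relaxed (soft) encodings, which is useful when a guide has been optimized in continuous space and needs to be read back as a discrete sequence.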
Training¤
```python
import jax.numpy as jnp
import optax
from flax import nnx

from diffbio.operators.crispr import create_crispr_scorer

scorer = create_crispr_scorer(guide_length=23)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(scorer, nnx.Param))

def loss_fn(model, guides, target_scores):
    """MSE loss for efficiency prediction."""
    result, _, _ = model.apply({"guides": guides}, {}, None)
    predicted = result["efficiency_scores"]
    return jnp.mean((predicted - target_scores) ** 2)

# Use the NNX transforms: they accept modules as arguments and
# propagate in-place state updates out of the compiled function.
@nnx.jit
def train_step(model, opt_state, guides, targets):
    loss, grads = nnx.value_and_grad(loss_fn)(model, guides, targets)
    params = nnx.state(model, nnx.Param)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    nnx.update(model, optax.apply_updates(params, updates))
    return loss, opt_state

# Training loop
# train_guides: (n, 23, 4) one-hot guides; train_scores: (n,) efficiencies in [0, 1]
scorer.train()
for epoch in range(100):
    loss, opt_state = train_step(scorer, opt_state, train_guides, train_scores)
scorer.eval()
```
Inference¤
```python
scorer.eval()

# Score a batch of guides
result, _, _ = scorer.apply({"guides": test_guides}, {}, None)
scores = result["efficiency_scores"]

# Rank guides by predicted efficiency (highest first)
ranked_indices = jnp.argsort(scores)[::-1]
top_guides = test_guides[ranked_indices[:10]]
top_scores = scores[ranked_indices[:10]]

print("Top 10 guides by predicted efficiency:")
for rank, score in enumerate(top_scores, start=1):
    print(f"{rank}. Score: {float(score):.3f}")
```
Batch Processing¤
```python
# Process many guides efficiently
batch_size = 1000
all_scores = []
for i in range(0, len(all_guides), batch_size):
    batch = all_guides[i:i + batch_size]
    result, _, _ = scorer.apply({"guides": batch}, {}, None)
    all_scores.append(result["efficiency_scores"])
all_scores = jnp.concatenate(all_scores)
```
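The loop above produces a shorter final batch whenever the number of guides is not a multiple of `batch_size`. If the scoring call is compiled with `jax.jit` (an assumption; this page does not say whether it is), a new input shape triggers recompilation. One common workaround is to pad the last batch to full size and discard the padded outputs. The sketch below illustrates the bookkeeping on plain Python lists; `iter_padded_batches` is a hypothetical helper, not a DiffBio function.

```python
def iter_padded_batches(items, batch_size, pad_item):
    """Yield (batch, n_valid) with every batch padded to exactly batch_size."""
    for start in range(0, len(items), batch_size):
        batch = list(items[start:start + batch_size])
        n_valid = len(batch)  # number of real entries before padding
        batch.extend([pad_item] * (batch_size - n_valid))
        yield batch, n_valid


guides = [f"guide_{i}" for i in range(2500)]
batch_sizes = [
    (len(batch), n_valid)
    for batch, n_valid in iter_padded_batches(guides, 1000, pad_item="PAD")
]
print(batch_sizes)
```

With arrays, the same idea amounts to padding along the first axis and keeping only `scores[:n_valid]` from each batch before concatenating.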
CRISPR Guide Design Workflow¤
Complete Guide Selection Pipeline¤
```python
import jax.numpy as jnp

from diffbio.operators.crispr import create_crispr_scorer

# 1. Generate candidate guides from target sequence
def generate_candidates(target_seq, pam="NGG"):
    """Extract 23-nt windows (20-nt guide + PAM) at forward-strand PAM sites."""
    candidates = []
    for i in range(len(target_seq) - 22):
        window = target_seq[i:i + 23]
        # "N" in the PAM pattern matches any nucleotide
        pam_ok = all(p in ("N", b) for p, b in zip(pam, window[20:]))
        # Skip windows with ambiguous bases, which encode_guide cannot encode
        if pam_ok and set(window) <= set("ACGT"):
            candidates.append(window)
    return candidates

# 2. Score guides for on-target efficiency
# encode_guide is defined under Guide Encoding; target_dna is the target sequence string
scorer = create_crispr_scorer(guide_length=23)
scorer.eval()
candidate_seqs = generate_candidates(target_dna)
encoded = jnp.stack([encode_guide(seq) for seq in candidate_seqs])
result, _, _ = scorer.apply({"guides": encoded}, {}, None)
efficiency_scores = result["efficiency_scores"]

# 3. Filter by efficiency threshold
threshold = 0.7
good_guides = [
    (seq, float(score))
    for seq, score in zip(candidate_seqs, efficiency_scores)
    if score > threshold
]
print(f"Found {len(good_guides)} guides with efficiency > {threshold}")
```
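The candidate generator above scans only the strand it is given, whereas Cas9 can target either strand. A minimal sketch of a two-strand search follows, assuming an NGG PAM and unambiguous A/C/G/T input; `reverse_complement` and `find_ngg_sites` are hypothetical helpers written for this example.

```python
COMPLEMENT = str.maketrans("ACGT", "TGCA")


def reverse_complement(seq):
    """Complement each base, then reverse to read the opposite strand 5' to 3'."""
    return seq.translate(COMPLEMENT)[::-1]


def find_ngg_sites(target_seq):
    """Return (strand, start, window) for 23-nt guide+PAM windows on both strands.

    start is the 0-based index of the window on the scanned strand, so
    minus-strand positions are relative to the reverse complement.
    """
    hits = []
    for strand, seq in (("+", target_seq), ("-", reverse_complement(target_seq))):
        for i in range(len(seq) - 22):
            window = seq[i:i + 23]
            # NGG PAM: the last two bases of the window must be "GG"
            if window[21:] == "GG" and set(window) <= set("ACGT"):
                hits.append((strand, i, window))
    return hits


forward_target = "A" * 20 + "TGG"       # PAM on the given strand
reverse_target = "CCA" + "T" * 20       # PAM only on the opposite strand
print(find_ngg_sites(forward_target))
print(find_ngg_sites(reverse_target))
```

The returned windows are already written 5' to 3' on their own strand, so they can be passed to the same encoder as forward-strand candidates.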
Use Cases¤
| Application | Operator | Description |
|---|---|---|
| Guide selection | DifferentiableCRISPRScorer | Rank guides by efficiency |
| Library design | DifferentiableCRISPRScorer | Score large guide libraries |
| Optimization | DifferentiableCRISPRScorer | Gradient-based guide design |
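Gradient-based guide design relies on relaxing the discrete sequence into continuous per-position probabilities, so that the score can be differentiated with respect to the sequence itself. The sketch below illustrates the idea with a toy stand-in scorer (a fixed random position-weight matrix followed by a sigmoid). This toy scorer is an assumption made so the example is self-contained; it is not the DiffBio model, and in practice a trained scorer would take its place.

```python
import jax
import jax.numpy as jnp

GUIDE_LENGTH, ALPHABET_SIZE = 23, 4

# Hypothetical stand-in for a trained scorer: fixed random per-position weights
weights = jax.random.normal(jax.random.PRNGKey(0), (GUIDE_LENGTH, ALPHABET_SIZE))


def toy_score(soft_guide):
    """Map a (length, 4) relaxed one-hot guide to a score in (0, 1)."""
    return jax.nn.sigmoid(jnp.sum(weights * soft_guide))


def objective(logits):
    # Softmax turns unconstrained logits into per-position nucleotide probabilities
    return toy_score(jax.nn.softmax(logits, axis=-1))


value_and_grad = jax.jit(jax.value_and_grad(objective))

logits = jnp.zeros((GUIDE_LENGTH, ALPHABET_SIZE))  # start from a uniform guide
initial_score, _ = value_and_grad(logits)

learning_rate = 0.5
for _ in range(200):
    _, grads = value_and_grad(logits)
    logits = logits + learning_rate * grads  # plain gradient ascent on the score

final_score = objective(logits)
# Read the relaxed design back as a discrete sequence
designed = "".join("ACGT"[int(i)] for i in jnp.argmax(logits, axis=-1))
print(float(initial_score), float(final_score), designed)
```

A real design loop would typically also constrain the PAM positions and re-score the discretized sequence, since the arg-max sequence can score differently from its relaxed counterpart.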
References¤
- Chuai et al. (2018). "DeepCRISPR: Optimized CRISPR guide RNA design by deep learning." Genome Biology.
- Xiang et al. (2021). "Enhancing CRISPR-Cas9 gRNA efficiency prediction by data integration and deep learning." Nature Communications.
- Wessels et al. (2020). "Massively parallel Cas13 screens reveal principles for guide RNA design." Nature Biotechnology.
Next Steps¤
- See Alignment Operators for off-target search
- Explore Preprocessing Operators for sequence preparation