Drug Discovery Operators API
Differentiable operators for molecular property prediction, fingerprint computation, and similarity scoring.
MolecularPropertyPredictor
diffbio.operators.drug_discovery.property_predictor.MolecularPropertyPredictor
```python
MolecularPropertyPredictor(
    config: MolecularPropertyConfig,
    *,
    rngs: Rngs | None = None,
)
```
Bases: OperatorModule
ChemProp-style molecular property predictor.
Implements a directed message passing neural network (D-MPNN) for predicting molecular properties from graph representations.
The architecture consists of:
1. Message passing layers that compute atom representations
2. Graph-level readout via sum pooling
3. A feed-forward network for property prediction
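The three stages above can be illustrated with a small NumPy sketch. It is not DiffBio's implementation: it passes messages between atoms rather than along directed bonds as a true D-MPNN does, and the weight names (`w_in`, `w_self`, `w_msg`, `w_ffn`, `w_out`) are hypothetical. The padding check at the end shows why `node_mask` matters for sum pooling.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def predict_properties(node_features, adjacency, node_mask, params, num_steps=3):
    """Message passing -> sum pooling -> feed-forward head (toy sketch)."""
    # Embed raw atom features into the hidden space.
    h = relu(node_features @ params["w_in"])
    for _ in range(num_steps):
        # Each atom sums the hidden states of its bonded neighbours.
        messages = adjacency @ h
        h = relu(h @ params["w_self"] + messages @ params["w_msg"])
    # Sum-pool the valid atoms into a single molecule vector.
    graph_vec = (h * node_mask[:, None]).sum(axis=0)
    # Feed-forward network: one output per task.
    return relu(graph_vec @ params["w_ffn"]) @ params["w_out"]

rng = np.random.default_rng(0)
in_features, hidden_dim, num_tasks = 4, 8, 3
params = {
    "w_in": rng.normal(size=(in_features, hidden_dim)),
    "w_self": 0.1 * rng.normal(size=(hidden_dim, hidden_dim)),
    "w_msg": 0.1 * rng.normal(size=(hidden_dim, hidden_dim)),
    "w_ffn": rng.normal(size=(hidden_dim, hidden_dim)),
    "w_out": rng.normal(size=(hidden_dim, num_tasks)),
}
# Three heavy atoms in a chain (e.g. the C-C-O skeleton of ethanol).
node_features = rng.normal(size=(3, in_features))
adjacency = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
node_mask = np.ones(3)
predictions = predict_properties(node_features, adjacency, node_mask, params)

# Padding with an extra masked-out, unconnected atom leaves the result unchanged.
pad_features = np.vstack([node_features, np.zeros((1, in_features))])
pad_adjacency = np.pad(adjacency, ((0, 1), (0, 1)))
pad_mask = np.array([1.0, 1.0, 1.0, 0.0])
padded = predict_properties(pad_features, pad_adjacency, pad_mask, params)
```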
Example
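```python
from flax import nnx
from diffbio.operators.drug_discovery import (
    MolecularPropertyConfig,
    MolecularPropertyPredictor,
)

config = MolecularPropertyConfig(hidden_dim=64, num_output_tasks=3)
predictor = MolecularPropertyPredictor(config, rngs=nnx.Rngs(42))
# node_features: (num_nodes, in_features), adjacency: (num_nodes, num_nodes),
# mask: (num_nodes,). in_features defaults to 4 in this config.
data = {
    "node_features": node_features,
    "adjacency": adjacency,
    "node_mask": mask,
}
result, state, meta = predictor.apply(data, {}, None)
predictions = result["predictions"]  # shape: (3,)
```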
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MolecularPropertyConfig | Predictor configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]
```
Predict molecular properties from graph representation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Input data with keys node_features: (num_nodes, num_features) atom features; adjacency: (num_nodes, num_nodes) adjacency matrix; edge_features (optional): (num_nodes, num_nodes, num_edge_features); node_mask: (num_nodes,) mask for valid nodes. | required |
| state | dict[str, Any] | Per-element state (passed through). | required |
| metadata | dict[str, Any] \| None | Optional metadata. | required |
| random_params | Any | Unused random parameters. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None] | Tuple of (data with an added "predictions" key, unchanged state, unchanged metadata). |
MolecularPropertyConfig
diffbio.operators.drug_discovery.property_predictor.MolecularPropertyConfig
dataclass
```python
MolecularPropertyConfig(
    hidden_dim: int = 300,
    num_message_passing_steps: int = 3,
    num_output_tasks: int = 1,
    dropout_rate: float = 0.0,
    in_features: int = 4,
    num_edge_features: int = 4,
)
```
Bases: OperatorConfig
Configuration for molecular property predictor.
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension for message passing layers. |
| num_message_passing_steps | int | Number of message passing iterations. |
| num_output_tasks | int | Number of prediction tasks (multi-task learning). |
| dropout_rate | float | Dropout rate for regularization. |
| in_features | int | Number of input node features (default: 4). Set to DEFAULT_ATOM_FEATURES (34) for features produced by smiles_to_graph. |
| num_edge_features | int | Number of edge/bond features. |
ADMETPredictor
diffbio.operators.drug_discovery.admet_predictor.ADMETPredictor
```python
ADMETPredictor(
    config: ADMETConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
```
Bases: OperatorModule
Multi-task ADMET property predictor.
Implements a ChemProp-style directed message passing neural network for predicting multiple ADMET properties simultaneously. The architecture uses a shared molecular encoder with task-specific prediction heads.
Architecture
- Message passing encoder (D-MPNN style)
- Graph-level readout via sum pooling
- Shared feed-forward layers
- Task-specific output heads
The 22 standard TDC ADMET endpoints cover:
- Absorption: Caco2, HIA, Pgp, Bioavailability, Lipophilicity, Solubility
- Distribution: BBB, PPBR, VDss
- Metabolism: CYP enzymes (2C9, 2D6, 3A4) inhibition and substrate
- Excretion: Half-life, Hepatocyte clearance, Microsome clearance
- Toxicity: LD50, hERG, AMES, DILI
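The shared-encoder, per-task-head design can be sketched in NumPy as follows. The encoder output is stubbed with a random vector, the weights and the even/odd split into classification and regression tasks are hypothetical, and the sigmoid step mirrors what `ADMETConfig.apply_task_activations` is documented to do rather than DiffBio's actual code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def multitask_heads(graph_vec, trunk_w, head_w, is_classification,
                    apply_task_activations=False):
    """Shared feed-forward trunk plus one linear head per task (toy sketch)."""
    shared = np.maximum(graph_vec @ trunk_w, 0.0)  # shared FFN layer
    outputs = head_w @ shared                      # row t of head_w is the head for task t
    if apply_task_activations:
        # Map classification endpoints to probabilities; regression
        # endpoints keep their raw values.
        outputs = np.where(is_classification, sigmoid(outputs), outputs)
    return outputs

rng = np.random.default_rng(0)
hidden_dim, num_tasks = 16, 22
graph_vec = rng.normal(size=hidden_dim)            # stand-in for the encoder output
trunk_w = 0.1 * rng.normal(size=(hidden_dim, hidden_dim))
head_w = 0.1 * rng.normal(size=(num_tasks, hidden_dim))
is_classification = np.arange(num_tasks) % 2 == 0  # hypothetical task split
raw = multitask_heads(graph_vec, trunk_w, head_w, is_classification)
activated = multitask_heads(graph_vec, trunk_w, head_w, is_classification,
                            apply_task_activations=True)
```

Because every head reads the same shared vector, gradients from all 22 endpoints update the common encoder, which is the usual motivation for multi-task ADMET models.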
References
- https://tdcommons.ai/benchmark/admet_group/overview/
- Yang et al. "Analyzing Learned Molecular Representations for Property Prediction." J. Chem. Inf. Model. 2019
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ADMETConfig | ADMET configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional name for the operator. | None |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]
```
Predict ADMET properties from molecular graph.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Input data with keys node_features: (num_nodes, num_features) atom features; adjacency: (num_nodes, num_nodes) adjacency matrix; edge_features (optional): (num_nodes, num_nodes, num_edge_features); node_mask: (num_nodes,) mask for valid nodes. | required |
| state | dict[str, Any] | Per-element state (passed through). | required |
| metadata | dict[str, Any] \| None | Optional metadata. | required |
| random_params | Any | Unused random parameters. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None] | Tuple of (data with added "predictions" and "task_predictions" keys, unchanged state, unchanged metadata). |
ADMETConfig
diffbio.operators.drug_discovery.admet_predictor.ADMETConfig
dataclass
```python
ADMETConfig(
    hidden_dim: int = 300,
    num_message_passing_steps: int = 3,
    num_tasks: int = 22,
    dropout_rate: float = 0.0,
    in_features: int = 4,
    num_edge_features: int = 4,
    ffn_hidden_dim: int | None = None,
    ffn_num_layers: int = 2,
    apply_task_activations: bool = False,
)
```
Bases: OperatorConfig
Configuration for ADMET property predictor.
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension for message passing (default: 300). |
| num_message_passing_steps | int | Number of D-MPNN iterations (default: 3). |
| num_tasks | int | Number of ADMET prediction tasks (default: 22). |
| dropout_rate | float | Dropout rate for regularization (default: 0.0). |
| in_features | int | Number of input node features (default: 4). |
| num_edge_features | int | Number of edge features (default: 4). |
| ffn_hidden_dim | int \| None | FFN hidden dimension (default: None, in which case hidden_dim is used). |
| ffn_num_layers | int | Number of FFN layers (default: 2). |
| apply_task_activations | bool | Whether to apply a sigmoid to the outputs of classification tasks (default: False). |
MACCSKeysOperator
diffbio.operators.drug_discovery.maccs_keys.MACCSKeysOperator
```python
MACCSKeysOperator(
    config: MACCSKeysConfig, *, rngs: Rngs | None = None
)
```
Bases: OperatorModule
Differentiable MACCS structural keys fingerprint operator.
For differentiable=True: Uses message passing and learned pattern detectors to approximate MACCS key detection. Each of the 166 keys is represented by a learned pattern matching network that outputs a soft presence score.
For differentiable=False: Would use RDKit's exact MACCS implementation (not differentiable).
The differentiable version enables gradient flow for end-to-end optimization while approximating the structural pattern detection of traditional MACCS keys.
MACCS keys encode various structural features:
- Atom types (C, N, O, S, halogens, etc.)
- Functional groups (carbonyl, hydroxyl, amine, etc.)
- Ring systems (aromatic, aliphatic)
- Bond patterns and connectivity
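A minimal sketch of how per-atom detector scores could become soft key bits. It assumes that a key is present when any valid atom matches, that `temperature` divides the logit (so a lower temperature gives sharper bits), and that padded atoms are excluded via the mask; these are illustrative choices, not confirmed details of `MACCSKeysOperator`.

```python
import numpy as np

def soft_key_bits(atom_scores, node_mask, temperature=1.0):
    """Soft presence score per key from per-atom pattern-detector scores.

    atom_scores: (num_nodes, n_bits), one column per structural key.
    A key is present if any valid atom matches it, so take the best score
    over valid atoms, then apply a temperature-scaled sigmoid.
    """
    masked = np.where(node_mask[:, None] > 0, atom_scores, -np.inf)
    best = masked.max(axis=0)
    return 1.0 / (1.0 + np.exp(-best / temperature))

rng = np.random.default_rng(0)
scores = rng.normal(size=(5, 166))        # 5 atom slots, 166 MACCS keys
mask = np.array([1.0, 1.0, 1.0, 1.0, 0.0])
scores[4] = 100.0                          # padded slot with a spurious score
soft = soft_key_bits(scores, mask, temperature=1.0)
sharp = soft_key_bits(scores, mask, temperature=0.1)
```

The max over atoms has a subgradient everywhere; a log-sum-exp would be a smoother alternative if gradients through the aggregation matter.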
References
- Durant et al. "Reoptimization of MDL Keys for Use in Drug Discovery." J. Chem. Inf. Comput. Sci. 2002
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MACCSKeysConfig | MACCS keys configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]
```
Compute MACCS keys fingerprint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Input data. For differentiable=True: node_features: (num_nodes, num_features) atom features; adjacency: (num_nodes, num_nodes) adjacency matrix; node_mask (optional): (num_nodes,) mask for valid nodes. For differentiable=False: smiles: SMILES string. | required |
| state | dict[str, Any] | Per-element state (passed through). | required |
| metadata | dict[str, Any] \| None | Optional metadata. | required |
| random_params | Any | Unused random parameters. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None] | Tuple of (data with an added "fingerprint" key, unchanged state, unchanged metadata). |
MACCSKeysConfig
diffbio.operators.drug_discovery.maccs_keys.MACCSKeysConfig
dataclass
```python
MACCSKeysConfig(
    n_bits: int = 166,
    differentiable: bool = True,
    temperature: float = 1.0,
    hidden_dim: int = 64,
    num_layers: int = 2,
    in_features: int = 4,
)
```
Bases: OperatorConfig
Configuration for MACCS keys fingerprint operator.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_bits | int | Number of fingerprint bits (default: 166 for standard MACCS). |
| differentiable | bool | Use learned pattern matching (default: True). |
| temperature | float | Temperature for soft bit assignment (default: 1.0). |
| hidden_dim | int | Hidden dimension for pattern networks (default: 64). |
| num_layers | int | Number of message passing layers (default: 2). |
| in_features | int | Number of input node features (default: 4). |
AttentiveFP
diffbio.operators.drug_discovery.attentive_fp.AttentiveFP
```python
AttentiveFP(
    config: AttentiveFPConfig, *, rngs: Rngs | None = None
)
```
Bases: OperatorModule
AttentiveFP: Attention-based molecular fingerprint.
Implements the AttentiveFP architecture with:
- Atom-level attention layers with GRU refinement
- Molecule-level aggregation with attention and GRU
- Final projection to fingerprint dimension
The model provides interpretable attention weights that indicate which atoms contribute most to the molecular representation.
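The interpretable attention weights come from a softmax over atoms. The NumPy sketch below shows one molecule-level attention step under simplifying assumptions: a single hypothetical scoring vector `w_score`, LeakyReLU using the config's default `negative_slope` of 0.2, and no GRU refinement. It follows the general mechanism of Xiong et al. rather than DiffBio's exact layers.

```python
import numpy as np

def leaky_relu(x, negative_slope=0.2):
    return np.where(x > 0, x, negative_slope * x)

def attention_readout(atom_h, mol_h, node_mask, w_score, negative_slope=0.2):
    """One molecule-level attention step (the GRU update is omitted).

    Scores each atom against the current molecule state, normalises the
    scores over valid atoms with a softmax, and returns the weighted sum of
    atom states together with the attention weights.
    """
    n = atom_h.shape[0]
    pairs = np.concatenate([np.repeat(mol_h[None, :], n, axis=0), atom_h], axis=1)
    scores = leaky_relu(pairs @ w_score, negative_slope)  # (num_nodes,)
    scores = np.where(node_mask > 0, scores, -np.inf)     # ignore padding
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()
    context = weights @ atom_h                            # (hidden_dim,)
    return context, weights

rng = np.random.default_rng(0)
hidden_dim = 8
atom_h = rng.normal(size=(4, hidden_dim))
node_mask = np.array([1.0, 1.0, 1.0, 0.0])
mol_h = (atom_h * node_mask[:, None]).sum(axis=0)  # initial molecule state
w_score = rng.normal(size=2 * hidden_dim)
context, weights = attention_readout(atom_h, mol_h, node_mask, w_score)
```

The returned `weights` sum to 1 over real atoms, which is what makes them readable as per-atom importance.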
Example
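```python
from flax import nnx
from diffbio.operators.drug_discovery.attentive_fp import AttentiveFP, AttentiveFPConfig

config = AttentiveFPConfig(hidden_dim=128, out_dim=256)
afp = AttentiveFP(config, rngs=nnx.Rngs(42))
# With the remaining defaults: nodes (num_nodes, 39), adj (num_nodes, num_nodes),
# edges (num_nodes, num_nodes, 10).
data = {"node_features": nodes, "adjacency": adj, "edge_features": edges}
result, _, _ = afp.apply(data, {}, None)
fingerprint = result["fingerprint"]  # (256,)
attn = result["attention_weights"]  # per-atom weights, for interpretability
```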
References
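- Xiong et al. "Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism." J. Med. Chem. 2020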
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | AttentiveFPConfig | AttentiveFP configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]
```
Compute AttentiveFP molecular fingerprint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Input data with keys node_features: (num_nodes, in_features) atom features; adjacency: (num_nodes, num_nodes) adjacency matrix; edge_features (optional): (num_nodes, num_nodes, edge_dim); node_mask (optional): (num_nodes,) mask for valid nodes. | required |
| state | dict[str, Any] | Per-element state (passed through). | required |
| metadata | dict[str, Any] \| None | Optional metadata. | required |
| random_params | Any | Unused random parameters. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None] | Tuple of (data with "fingerprint" and "attention_weights" keys, unchanged state, unchanged metadata). |
AttentiveFPConfig
diffbio.operators.drug_discovery.attentive_fp.AttentiveFPConfig
dataclass
```python
AttentiveFPConfig(
    in_features: int = 39,
    edge_dim: int = 10,
    negative_slope: float = 0.2,
    hidden_dim: int = 200,
    out_dim: int = 200,
    num_layers: int = 2,
    num_timesteps: int = 2,
    dropout_rate: float = 0.0,
)
```
Bases: _AttentiveFPArchitectureConfig, _AttentiveFPInputConfig, OperatorConfig
Configuration for AttentiveFP operator.
DifferentiableMolecularFingerprint
diffbio.operators.drug_discovery.fingerprint.DifferentiableMolecularFingerprint
```python
DifferentiableMolecularFingerprint(
    config: MolecularFingerprintConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
```
Bases: OperatorModule
Neural graph fingerprint operator.
Computes learned molecular fingerprints using graph neural networks. Unlike traditional fingerprints (e.g., ECFP/Morgan), these are fully differentiable and can be optimized for specific tasks.
The fingerprint is computed by:
1. Message passing to compute atom representations
2. Sum pooling to get a graph-level representation
3. Linear projection to the fingerprint dimension
4. Optional L2 normalization
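The four steps can be sketched in NumPy. The tanh update over self plus neighbour sums and the weight names (`layer_ws`, `proj_w`) are assumptions made for illustration, not DiffBio's layers. The permutation check at the end shows the property that sum pooling buys: relabelling the atoms leaves the fingerprint unchanged.

```python
import numpy as np

def neural_fingerprint(node_features, adjacency, node_mask, layer_ws, proj_w,
                       normalize=False):
    """Toy neural graph fingerprint following the four steps above."""
    h = node_features
    for w in layer_ws:                                # 1. message passing
        h = np.tanh((h + adjacency @ h) @ w)          #    self + neighbour sum
    pooled = (h * node_mask[:, None]).sum(axis=0)     # 2. sum pooling
    fp = pooled @ proj_w                              # 3. linear projection
    if normalize:                                     # 4. optional L2 normalization
        fp = fp / (np.linalg.norm(fp) + 1e-8)
    return fp

rng = np.random.default_rng(0)
in_features, hidden_dim, fingerprint_dim = 4, 16, 32
layer_ws = [rng.normal(size=(in_features, hidden_dim))] + [
    rng.normal(size=(hidden_dim, hidden_dim)) for _ in range(2)
]
proj_w = rng.normal(size=(hidden_dim, fingerprint_dim))
x = rng.normal(size=(3, in_features))
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
mask = np.ones(3)
fp = neural_fingerprint(x, adj, mask, layer_ws, proj_w, normalize=True)

# Relabelling the atoms does not change the fingerprint.
perm = np.array([2, 0, 1])
fp_perm = neural_fingerprint(x[perm], adj[perm][:, perm], mask[perm],
                             layer_ws, proj_w, normalize=True)
```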
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MolecularFingerprintConfig | Fingerprint configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
| name | str \| None | Optional name for the operator. | None |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]
```
Compute molecular fingerprint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Input data with keys node_features: (num_nodes, num_features) atom features; adjacency: (num_nodes, num_nodes) adjacency matrix; node_mask: (num_nodes,) mask for valid nodes. | required |
| state | dict[str, Any] | Per-element state (passed through). | required |
| metadata | dict[str, Any] \| None | Optional metadata. | required |
| random_params | Any | Unused random parameters. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None] | Tuple of (data with an added "fingerprint" key, unchanged state, unchanged metadata). |
MolecularFingerprintConfig
diffbio.operators.drug_discovery.fingerprint.MolecularFingerprintConfig
dataclass
```python
MolecularFingerprintConfig(
    fingerprint_dim: int = 256,
    hidden_dim: int = 128,
    num_layers: int = 3,
    in_features: int = 4,
    normalize: bool = False,
)
```
Bases: OperatorConfig
Configuration for molecular fingerprint operator.
Attributes:
| Name | Type | Description |
|---|---|---|
| fingerprint_dim | int | Dimension of output fingerprint vector. |
| hidden_dim | int | Hidden dimension for graph convolutions. |
| num_layers | int | Number of graph convolution layers. |
| in_features | int | Number of input node features (default: 4). Set to DEFAULT_ATOM_FEATURES (34) for features produced by smiles_to_graph. |
| normalize | bool | Whether to L2-normalize the fingerprint. |
CircularFingerprintOperator
diffbio.operators.drug_discovery.fingerprint.CircularFingerprintOperator
```python
CircularFingerprintOperator(
    config: CircularFingerprintConfig,
    *,
    rngs: Rngs | None = None,
)
```
Bases: OperatorModule
Differentiable circular fingerprints (ECFP/Morgan).
For differentiable=True: uses message passing to aggregate substructure information, then learned "soft hash" functions for bit assignment. Gradients flow through the entire computation.
For differentiable=False: wraps the RDKit implementation to compute exact ECFP, with no gradient flow (useful for inference or comparison).
The differentiable version approximates ECFP behavior while enabling end-to-end optimization of the fingerprint representation.
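One way to read "soft hash" is sketched below, assuming each atom environment at radius 0 to `radius` is mapped to a softmax distribution over `n_bits` bits (sharpened by `temperature`) and the distributions are combined with a soft OR. The tanh aggregation, the soft-OR rule, and all weight names (`in_w`, `layer_ws`, `hash_w`) are hypothetical; DiffBio's hash network may differ.

```python
import numpy as np

def softmax(x, axis=-1):
    z = np.exp(x - x.max(axis=axis, keepdims=True))
    return z / z.sum(axis=axis, keepdims=True)

def soft_circular_fp(node_features, adjacency, node_mask, in_w, layer_ws,
                     hash_w, temperature=1.0):
    """ECFP-like sketch with a learned soft hash.

    Radius 0 is each atom's own embedding; each extra round of neighbour
    aggregation widens the environment by one bond (len(layer_ws) == radius).
    Every atom environment is hashed to a softmax distribution over bits,
    and the per-environment distributions are merged with a soft OR.
    """
    h = np.tanh(node_features @ in_w)
    envs = [h]
    for w in layer_ws:
        h = np.tanh((h + adjacency @ h) @ w)
        envs.append(h)
    miss = np.ones(hash_w.shape[1])                      # P(bit still unset)
    for env in envs:
        p = softmax(env @ hash_w / temperature, axis=-1)  # (num_nodes, n_bits)
        p = p * node_mask[:, None]                        # padded atoms set no bits
        miss = miss * np.prod(1.0 - p, axis=0)
    return 1.0 - miss                                     # soft OR over environments

rng = np.random.default_rng(0)
in_features, hidden_dim, n_bits, radius = 4, 16, 64, 2   # radius 2 ~ ECFP4
in_w = rng.normal(size=(in_features, hidden_dim))
layer_ws = [rng.normal(size=(hidden_dim, hidden_dim)) for _ in range(radius)]
hash_w = rng.normal(size=(hidden_dim, n_bits))
x = rng.normal(size=(3, in_features))
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
fp = soft_circular_fp(x, adj, np.ones(3), in_w, layer_ws, hash_w)
```

As `temperature` approaches 0 each environment's softmax approaches a one-hot bit choice, so the output approaches a binary ECFP-style vector while staying differentiable at finite temperature.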
References
- Rogers and Hahn. "Extended-Connectivity Fingerprints." J. Chem. Inf. Model. 50(5), 742-754 (2010)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | CircularFingerprintConfig | Circular fingerprint configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]
```
Compute circular fingerprint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Input data. For differentiable=True: node_features: (num_nodes, num_features) atom features; adjacency: (num_nodes, num_nodes) adjacency matrix; node_mask (optional): (num_nodes,) mask for valid nodes. For differentiable=False: smiles: SMILES string. | required |
| state | dict[str, Any] | Per-element state (passed through). | required |
| metadata | dict[str, Any] \| None | Optional metadata. | required |
| random_params | Any | Unused random parameters. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None] | Tuple of (data with an added "fingerprint" key, unchanged state, unchanged metadata). |
CircularFingerprintConfig
diffbio.operators.drug_discovery.fingerprint.CircularFingerprintConfig
dataclass
```python
CircularFingerprintConfig(
    radius: int = 2,
    n_bits: int = 2048,
    use_chirality: bool = False,
    use_bond_types: bool = True,
    use_features: bool = False,
    differentiable: bool = True,
    hash_hidden_dim: int = 128,
    temperature: float = 1.0,
    in_features: int = 4,
)
```
Bases: OperatorConfig
Configuration for circular fingerprint operator (ECFP/Morgan).
Attributes:
| Name | Type | Description |
|---|---|---|
| radius | int | Fingerprint radius. ECFP4 = radius 2, ECFP6 = radius 3. |
| n_bits | int | Number of bits in fingerprint (default: 2048). |
| use_chirality | bool | Include chirality in fingerprint (default: False). |
| use_bond_types | bool | Include bond type information (default: True). |
| use_features | bool | Use pharmacophoric features (FCFP variant, default: False). |
| differentiable | bool | Use learned hash functions for gradients (default: True). |
| hash_hidden_dim | int | Hidden dimension for hash network (default: 128). |
| temperature | float | Temperature for soft bit assignment (default: 1.0). |
| in_features | int | Number of input node features (default: 4). |
MolecularSimilarityOperator
diffbio.operators.drug_discovery.similarity.MolecularSimilarityOperator
```python
MolecularSimilarityOperator(
    config: MolecularSimilarityConfig,
    *,
    rngs: Rngs | None = None,
)
```
Bases: OperatorModule
Differentiable molecular similarity operator.
Computes similarity between molecular fingerprints using various differentiable metrics. Supports Tanimoto, cosine, and Dice similarity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MolecularSimilarityConfig | Similarity configuration. | required |
| rngs | Rngs \| None | Flax NNX random number generators. | None |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]
```
Compute similarity between two fingerprints.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Input data with keys fingerprint_a: first fingerprint vector; fingerprint_b: second fingerprint vector. | required |
| state | dict[str, Any] | Per-element state (passed through). | required |
| metadata | dict[str, Any] \| None | Optional metadata. | required |
| random_params | Any | Unused random parameters. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None] | Tuple of (data with an added "similarity" key, unchanged state, unchanged metadata). |
MolecularSimilarityConfig
diffbio.operators.drug_discovery.similarity.MolecularSimilarityConfig
dataclass
Bases: OperatorConfig
Configuration for molecular similarity operator.
Attributes:
| Name | Type | Description |
|---|---|---|
| similarity_type | str | Type of similarity metric ("tanimoto", "cosine", or "dice"). |
| temperature | float | Temperature for soft similarity (higher = sharper). |
Message Passing Layers
MessagePassingLayer
diffbio.operators.drug_discovery.message_passing.MessagePassingLayer
```python
MessagePassingLayer(
    hidden_dim: int,
    in_features: int = 4,
    num_edge_features: int = 4,
    *,
    rngs: Rngs,
)
```
Bases: Module
Directed message passing layer for molecular graphs.
Implements the D-MPNN message passing scheme where messages are passed along directed edges. Each node aggregates messages from its neighbors and updates its representation.
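A simplified, node-centred version of one message-passing step on a dense adjacency matrix is sketched below. A full D-MPNN keeps hidden states on directed edges and excludes the reverse edge when building messages; this sketch omits that, and the weights `w_self`, `w_msg`, `w_edge` and the ReLU update are assumptions. After one step, each atom has only seen its direct neighbours.

```python
import numpy as np

def message_passing_step(node_features, adjacency, w_self, w_msg,
                         edge_features=None, w_edge=None):
    """One round of neighbour aggregation on a dense adjacency matrix.

    node_features: (num_nodes, in_features)
    adjacency:     (num_nodes, num_nodes), adjacency[i, j] = 1 if j sends to i
    edge_features: optional (num_nodes, num_nodes, num_edge_features)
    Returns (num_nodes, hidden_dim).
    """
    msgs = node_features @ w_msg                   # message each node would send
    aggregated = adjacency @ msgs                  # sum over neighbours j of node i
    if edge_features is not None:
        edge_msgs = edge_features @ w_edge         # (N, N, hidden)
        aggregated = aggregated + (adjacency[:, :, None] * edge_msgs).sum(axis=1)
    return np.maximum(node_features @ w_self + aggregated, 0.0)

rng = np.random.default_rng(0)
in_features, hidden_dim, num_edge_features = 4, 8, 4
x = rng.normal(size=(3, in_features))
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
bonds = np.zeros((3, 3, num_edge_features))
bonds[0, 1, 0] = bonds[1, 0, 0] = 1.0   # one-hot bond type, e.g. single bond
bonds[1, 2, 0] = bonds[2, 1, 0] = 1.0
w_self = rng.normal(size=(in_features, hidden_dim))
w_msg = rng.normal(size=(in_features, hidden_dim))
w_edge = rng.normal(size=(num_edge_features, hidden_dim))
out = message_passing_step(x, adj, w_self, w_msg, bonds, w_edge)
```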
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Dimension of hidden node representations. |
| in_features | int | Number of input node features. |
| num_edge_features | int | Number of edge features (default 4 for bond types). |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Dimension of hidden representations. | required |
| in_features | int | Number of input node features; the default of 4 is intended for tests. | 4 |
| num_edge_features | int | Number of edge/bond features. | 4 |
| rngs | Rngs | Flax NNX random number generators. | required |
__call__
```python
__call__(
    node_features: ndarray,
    adjacency: ndarray,
    edge_features: ndarray | None = None,
) -> ndarray
```
Perform one step of message passing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node_features | ndarray | Node features of shape (num_nodes, in_features). | required |
| adjacency | ndarray | Adjacency matrix of shape (num_nodes, num_nodes). | required |
| edge_features | ndarray \| None | Optional edge features of shape (num_nodes, num_nodes, num_edge_features). | None |
Returns:
| Type | Description |
|---|---|
| ndarray | Updated node features of shape (num_nodes, hidden_dim). |
StackedMessagePassing
diffbio.operators.drug_discovery.message_passing.StackedMessagePassing
```python
StackedMessagePassing(
    hidden_dim: int,
    num_layers: int,
    in_features: int = 4,
    num_edge_features: int = 4,
    *,
    rngs: Rngs,
)
```
Bases: Module
Stack of message passing layers.
Applies multiple rounds of message passing to capture higher-order neighborhood information.
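Each extra round extends a node's view by one bond, so the number of layers bounds how large a substructure any atom can "see". The NumPy sketch below computes that receptive field for a five-atom chain; it illustrates the graph-theoretic effect of stacking, not DiffBio's code.

```python
import numpy as np

def receptive_field(adjacency, num_layers):
    """Boolean (num_nodes, num_nodes) matrix: entry [i, j] is True when node j
    can influence node i after num_layers rounds of neighbour aggregation."""
    n = adjacency.shape[0]
    step = (np.eye(n) + adjacency) > 0       # a node sees itself and its neighbours
    reach = np.eye(n, dtype=bool)
    for _ in range(num_layers):
        reach = (reach.astype(int) @ step.astype(int)) > 0
    return reach

# Five-atom chain: 0-1-2-3-4
chain = np.zeros((5, 5))
for i in range(4):
    chain[i, i + 1] = chain[i + 1, i] = 1.0
print(np.flatnonzero(receptive_field(chain, 1)[0]))  # [0 1]
print(np.flatnonzero(receptive_field(chain, 3)[0]))  # [0 1 2 3]
```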
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension for all layers. | required |
| num_layers | int | Number of message passing iterations. | required |
| in_features | int | Number of input node features; the default of 4 is intended for tests. | 4 |
| num_edge_features | int | Number of edge features. | 4 |
| rngs | Rngs | Flax NNX random number generators. | required |
__call__
```python
__call__(
    node_features: ndarray,
    adjacency: ndarray,
    edge_features: ndarray | None = None,
) -> ndarray
```
Apply multiple rounds of message passing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node_features | ndarray | Initial node features. | required |
| adjacency | ndarray | Adjacency matrix. | required |
| edge_features | ndarray \| None | Optional edge features. | None |
Returns:
| Type | Description |
|---|---|
| ndarray | Final node representations. |
Factory Functions
create_property_predictor
diffbio.operators.drug_discovery.property_predictor.create_property_predictor
```python
create_property_predictor(
    hidden_dim: int = 300,
    num_layers: int = 3,
    num_tasks: int = 1,
    dropout_rate: float = 0.0,
    seed: int = 42,
) -> MolecularPropertyPredictor
```
Create a molecular property predictor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension for message passing. | 300 |
| num_layers | int | Number of message passing steps. | 3 |
| num_tasks | int | Number of prediction tasks. | 1 |
| dropout_rate | float | Dropout rate. | 0.0 |
| seed | int | Random seed. | 42 |
Returns:
| Type | Description |
|---|---|
| MolecularPropertyPredictor | Configured MolecularPropertyPredictor. |
create_fingerprint_operator
diffbio.operators.drug_discovery.fingerprint.create_fingerprint_operator
```python
create_fingerprint_operator(
    fingerprint_dim: int = 256,
    num_layers: int = 3,
    normalize: bool = False,
    seed: int = 42,
) -> DifferentiableMolecularFingerprint
```
Create a molecular fingerprint operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fingerprint_dim | int | Output fingerprint dimension. | 256 |
| num_layers | int | Number of message passing layers. | 3 |
| normalize | bool | Whether to L2-normalize the output. | False |
| seed | int | Random seed. | 42 |
Returns:
| Type | Description |
|---|---|
| DifferentiableMolecularFingerprint | Configured DifferentiableMolecularFingerprint. |
create_ecfp4_operator
diffbio.operators.drug_discovery.fingerprint.create_ecfp4_operator
```python
create_ecfp4_operator(
    n_bits: int = 2048,
    differentiable: bool = True,
    rngs: Rngs | None = None,
) -> CircularFingerprintOperator
```
Create an ECFP4 (radius=2) fingerprint operator.
ECFP4 uses radius 2: each substructure contains atoms up to 2 bonds from its central atom, so its diameter is at most 4 bonds (the "4" in the name).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_bits | int | Number of fingerprint bits (default: 2048). | 2048 |
| differentiable | bool | Use learned hash functions (default: True). | True |
| rngs | Rngs \| None | Random number generators. | None |
Returns:
| Type | Description |
|---|---|
| CircularFingerprintOperator | Configured CircularFingerprintOperator. |
create_ecfp6_operator
diffbio.operators.drug_discovery.fingerprint.create_ecfp6_operator
```python
create_ecfp6_operator(
    n_bits: int = 2048,
    differentiable: bool = True,
    rngs: Rngs | None = None,
) -> CircularFingerprintOperator
```
Create an ECFP6 (radius=3) fingerprint operator.
ECFP6 uses radius 3: each substructure contains atoms up to 3 bonds from its central atom, so its diameter is at most 6 bonds (the "6" in the name).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_bits | int | Number of fingerprint bits (default: 2048). | 2048 |
| differentiable | bool | Use learned hash functions (default: True). | True |
| rngs | Rngs \| None | Random number generators. | None |
Returns:
| Type | Description |
|---|---|
| CircularFingerprintOperator | Configured CircularFingerprintOperator. |
create_fcfp4_operator
diffbio.operators.drug_discovery.fingerprint.create_fcfp4_operator
```python
create_fcfp4_operator(
    n_bits: int = 2048,
    differentiable: bool = True,
    rngs: Rngs | None = None,
) -> CircularFingerprintOperator
```
Create an FCFP4 (feature-based, radius=2) fingerprint operator.
FCFP4 uses pharmacophoric (functional-class) atom features instead of atom-identity properties. It is often better suited to finding molecules with similar biological activity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_bits | int | Number of fingerprint bits (default: 2048). | 2048 |
| differentiable | bool | Use learned hash functions (default: True). | True |
| rngs | Rngs \| None | Random number generators. | None |
Returns:
| Type | Description |
|---|---|
| CircularFingerprintOperator | Configured CircularFingerprintOperator. |
create_similarity_operator
diffbio.operators.drug_discovery.similarity.create_similarity_operator
```python
create_similarity_operator(
    similarity_type: str = "tanimoto",
    temperature: float = 1.0,
) -> MolecularSimilarityOperator
```
Create a molecular similarity operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| similarity_type | str | Type of similarity ("tanimoto", "cosine", or "dice"). | 'tanimoto' |
| temperature | float | Temperature parameter. | 1.0 |
Returns:
| Type | Description |
|---|---|
| MolecularSimilarityOperator | Configured MolecularSimilarityOperator. |
create_admet_predictor
diffbio.operators.drug_discovery.admet_predictor.create_admet_predictor
```python
create_admet_predictor(
    hidden_dim: int = 300,
    num_layers: int = 3,
    dropout_rate: float = 0.0,
    seed: int = 42,
) -> ADMETPredictor
```
Create an ADMET predictor with standard configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension for message passing. | 300 |
| num_layers | int | Number of message passing steps. | 3 |
| dropout_rate | float | Dropout rate. | 0.0 |
| seed | int | Random seed. | 42 |
Returns:
| Type | Description |
|---|---|
| ADMETPredictor | Configured ADMETPredictor. |
create_maccs_operator
diffbio.operators.drug_discovery.maccs_keys.create_maccs_operator
```python
create_maccs_operator(
    differentiable: bool = True,
    temperature: float = 1.0,
    seed: int = 42,
) -> MACCSKeysOperator
```
Create a MACCS keys fingerprint operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| differentiable | bool | Use learned pattern matching. | True |
| temperature | float | Temperature for soft matching. | 1.0 |
| seed | int | Random seed. | 42 |
Returns:
| Type | Description |
|---|---|
| MACCSKeysOperator | Configured MACCSKeysOperator. |
create_attentive_fp
diffbio.operators.drug_discovery.attentive_fp.create_attentive_fp
```python
create_attentive_fp(
    hidden_dim: int = 200,
    out_dim: int = 200,
    num_layers: int = 2,
    num_timesteps: int = 2,
    dropout_rate: float = 0.0,
    seed: int = 42,
) -> AttentiveFP
```
Create an AttentiveFP operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension for GNN layers. | 200 |
| out_dim | int | Output fingerprint dimension. | 200 |
| num_layers | int | Number of atom-level attention layers. | 2 |
| num_timesteps | int | Number of molecule-level GRU iterations. | 2 |
| dropout_rate | float | Dropout rate. | 0.0 |
| seed | int | Random seed. | 42 |
Returns:
| Type | Description |
|---|---|
| AttentiveFP | Configured AttentiveFP. |
Similarity Functions
tanimoto_similarity
diffbio.operators.drug_discovery.similarity.tanimoto_similarity
Compute differentiable Tanimoto similarity.
For continuous vectors, uses the generalized Tanimoto formula: T(a, b) = (a · b) / (|a|² + |b|² - a · b)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| a | ndarray | First fingerprint vector. | required |
| b | ndarray | Second fingerprint vector. | required |
| eps | float | Small constant for numerical stability. | 1e-08 |
Returns:
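| Type | Description |
|---|---|
| ndarray | Similarity score in [0, 1] for non-negative fingerprints; with negative entries the generalized formula ranges over [-1/3, 1]. |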
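The generalized formula can be checked directly in NumPy. The `eps` placement in the denominator is an assumption about the implementation. On binary fingerprints the formula reduces to the familiar intersection-over-union (Jaccard) form, and with negative entries, which unnormalized learned fingerprints can have, the score can fall below 0.

```python
import numpy as np

def tanimoto(a, b, eps=1e-8):
    """Generalized Tanimoto: (a . b) / (|a|^2 + |b|^2 - a . b)."""
    dot = np.dot(a, b)
    return dot / (np.dot(a, a) + np.dot(b, b) - dot + eps)

# On binary fingerprints the formula equals |A and B| / |A or B| (Jaccard).
a = np.array([1, 1, 0, 1, 0], dtype=float)
b = np.array([1, 0, 1, 1, 0], dtype=float)
jaccard = np.logical_and(a, b).sum() / np.logical_or(a, b).sum()
print(round(tanimoto(a, b), 6), jaccard)  # 0.5 0.5

# Opposite-signed vectors give a negative score.
print(round(tanimoto(np.array([1.0, 0.0]), np.array([-1.0, 0.0])), 6))  # -0.333333
```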
cosine_similarity
diffbio.operators.drug_discovery.similarity.cosine_similarity
Compute cosine similarity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| a | ndarray | First vector. | required |
| b | ndarray | Second vector. | required |
| eps | float | Small constant for numerical stability. | 1e-08 |
Returns:
| Type | Description |
|---|---|
| ndarray | Similarity score in [-1, 1]. |
dice_similarity
diffbio.operators.drug_discovery.similarity.dice_similarity
Compute Dice similarity coefficient.
For continuous vectors: Dice(a, b) = 2 * (a · b) / (|a|² + |b|²)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| a | ndarray | First vector. | required |
| b | ndarray | Second vector. | required |
| eps | float | Small constant for numerical stability. | 1e-08 |
Returns:
| Type | Description |
|---|---|
| ndarray | Similarity score in [0, 1] for non-negative vectors (in [-1, 1] in general). |
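A short NumPy comparison of the three metrics follows, with `eps` placed in each denominator as an assumption. For these continuous definitions Dice is a monotone function of Tanimoto, D = 2T / (1 + T), so the two rank neighbours identically, while cosine ignores vector length entirely.

```python
import numpy as np

def tanimoto(a, b, eps=1e-8):
    dot = np.dot(a, b)
    return dot / (np.dot(a, a) + np.dot(b, b) - dot + eps)

def dice(a, b, eps=1e-8):
    """Continuous Dice: 2 (a . b) / (|a|^2 + |b|^2)."""
    return 2.0 * np.dot(a, b) / (np.dot(a, a) + np.dot(b, b) + eps)

def cosine(a, b, eps=1e-8):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + eps)

rng = np.random.default_rng(0)
a, b = rng.random(16), rng.random(16)   # non-negative soft fingerprints
t, d, c = tanimoto(a, b), dice(a, b), cosine(a, b)
print(np.isclose(d, 2 * t / (1 + t)))   # True
```

For non-negative inputs the three scores are ordered Tanimoto <= Dice <= cosine, so the same pair of molecules looks "more similar" under Dice or cosine; thresholds tuned for one metric do not transfer to another.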
Graph Conversion Utilities
smiles_to_graph
diffbio.operators.drug_discovery.primitives.smiles_to_graph
Convert a SMILES string to a molecular graph.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| smiles | str | SMILES string representing a molecule. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing node_features: (num_atoms, num_features) atom feature matrix; adjacency: (num_atoms, num_atoms) adjacency matrix; edge_features: (num_atoms, num_atoms, num_edge_features) bond features; num_nodes: number of atoms. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the SMILES string is invalid. |
batch_smiles_to_graphs
diffbio.operators.drug_discovery.primitives.batch_smiles_to_graphs
Convert a batch of SMILES strings to padded graph tensors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| smiles_list | list[str] | List of SMILES strings. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary with keys node_features: (batch_size, max_nodes, num_features); adjacency: (batch_size, max_nodes, max_nodes); edge_features: (batch_size, max_nodes, max_nodes, num_edge_features); node_mask: (batch_size, max_nodes) mask for valid nodes. |
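The padding step can be sketched in plain JAX, assuming each molecule is zero-padded to the largest atom count in the batch (`max_nodes`) and the mask holds 1.0 for real atoms and 0.0 for padding. `pad_graphs` and `toy_graph` are hypothetical helpers that operate on dictionaries shaped like the `smiles_to_graph` output; they are not part of diffbio, and the toy feature sizes are assumptions.

```python
import jax.numpy as jnp


def pad_graphs(graphs: list[dict]) -> dict:
    """Hypothetical helper: zero-pad per-molecule graphs to a common size."""
    max_nodes = max(int(g["num_nodes"]) for g in graphs)
    node_feats, adjs, edge_feats, masks = [], [], [], []
    for g in graphs:
        n = int(g["num_nodes"])
        pad = max_nodes - n
        node_feats.append(jnp.pad(g["node_features"], ((0, pad), (0, 0))))
        adjs.append(jnp.pad(g["adjacency"], ((0, pad), (0, pad))))
        edge_feats.append(jnp.pad(g["edge_features"], ((0, pad), (0, pad), (0, 0))))
        # 1.0 for real atoms, 0.0 for padding rows
        masks.append(jnp.concatenate([jnp.ones(n), jnp.zeros(pad)]))
    return {
        "node_features": jnp.stack(node_feats),
        "adjacency": jnp.stack(adjs),
        "edge_features": jnp.stack(edge_feats),
        "node_mask": jnp.stack(masks),
    }


def toy_graph(n: int, num_features: int = 34, num_edge_features: int = 4) -> dict:
    """Toy chain molecule: atom i is bonded to atom i + 1 (illustrative data only)."""
    return {
        "node_features": jnp.ones((n, num_features)),
        "adjacency": jnp.eye(n, k=1) + jnp.eye(n, k=-1),
        "edge_features": jnp.zeros((n, n, num_edge_features)),
        "num_nodes": n,
    }


batch = pad_graphs([toy_graph(3), toy_graph(4), toy_graph(6)])
print(batch["node_features"].shape)  # (3, 6, 34)
print(batch["node_mask"].shape)      # (3, 6)
```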
AtomFeatureConfig
diffbio.operators.drug_discovery.primitives.AtomFeatureConfig
dataclass
AtomFeatureConfig(
num_atom_types: int = 12,
max_degree: int = 6,
charge_range: tuple[int, int] = (-2, 2),
num_hybridization_types: int = 4,
max_num_hydrogens: int = 4,
)
Configuration for atom feature extraction.
The default configuration produces 34 features:
- Atom type: 12 dimensions (C, N, O, S, F, Cl, Br, I, P, Si, B, Other)
- Degree: 7 dimensions (0-6)
- Formal charge: 5 dimensions (-2 to +2)
- Hybridization: 4 dimensions (SP, SP2, SP3, SP3D)
- Aromaticity: 1 dimension (binary)
- Num hydrogens: 5 dimensions (0-4)
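The 34-feature total follows from the configuration fields. A small sketch of that arithmetic, assuming each categorical field is one-hot encoded over the ranges listed above; `AtomFeatureConfigSketch` and `feature_dim` are hypothetical stand-ins that mirror the documented defaults, not diffbio APIs.

```python
from dataclasses import dataclass


@dataclass
class AtomFeatureConfigSketch:
    """Hypothetical mirror of the documented AtomFeatureConfig defaults."""
    num_atom_types: int = 12
    max_degree: int = 6
    charge_range: tuple[int, int] = (-2, 2)
    num_hybridization_types: int = 4
    max_num_hydrogens: int = 4


def feature_dim(cfg: AtomFeatureConfigSketch) -> int:
    """Total feature width, assuming one-hot blocks as in the breakdown above."""
    low, high = cfg.charge_range
    return (
        cfg.num_atom_types               # atom type one-hot
        + (cfg.max_degree + 1)           # degree 0..max_degree
        + (high - low + 1)               # formal charge low..high
        + cfg.num_hybridization_types    # hybridization one-hot
        + 1                              # aromaticity flag
        + (cfg.max_num_hydrogens + 1)    # hydrogen count 0..max_num_hydrogens
    )


print(feature_dim(AtomFeatureConfigSketch()))              # 12 + 7 + 5 + 4 + 1 + 5 = 34
print(feature_dim(AtomFeatureConfigSketch(max_degree=5)))  # 33
```

Under this reading, changing a field such as max_degree changes the feature width, so in_features in the operator configs would presumably need to match the new width rather than DEFAULT_ATOM_FEATURES.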
Constants
DEFAULT_ATOM_FEATURES
diffbio.operators.drug_discovery.primitives.DEFAULT_ATOM_FEATURES
module-attribute
Number of atom features produced by the default AtomFeatureConfig (34).
DEFAULT_ATOM_CONFIG
diffbio.operators.drug_discovery.primitives.DEFAULT_ATOM_CONFIG
module-attribute
DEFAULT_ATOM_CONFIG = AtomFeatureConfig()
Usage Examples
Property Prediction
from diffbio.operators.drug_discovery import (
    MolecularPropertyPredictor,
    MolecularPropertyConfig,
    smiles_to_graph,
    DEFAULT_ATOM_FEATURES,
)
from flax import nnx
# Create predictor
config = MolecularPropertyConfig(
    hidden_dim=64,
    num_message_passing_steps=3,
    num_output_tasks=1,
    in_features=DEFAULT_ATOM_FEATURES,
)
predictor = MolecularPropertyPredictor(config, rngs=nnx.Rngs(42))
# Convert SMILES to a graph (smiles_to_graph returns a dictionary)
graph = smiles_to_graph("CCO")
# Predict properties
data = {
    "node_features": graph["node_features"],
    "adjacency": graph["adjacency"],
    "edge_features": graph["edge_features"],
}
result, _, _ = predictor.apply(data, {}, None)
predictions = result["predictions"]  # (1,)
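To make the data flow of `apply` concrete, the sketch below runs an untrained toy model in plain JAX through the three stages described for MolecularPropertyPredictor: message passing over bonded neighbours, sum pooling over valid atoms, and a readout to `num_tasks` outputs. It is a simplified node-level scheme with a single linear readout, assumed for illustration; it is not ChemProp's directed edge-based D-MPNN, and `toy_predict` and its weight names are hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_predict(params, node_features, adjacency, node_mask, num_steps=3):
    """Hypothetical untrained predictor: message passing, masked sum pooling, linear head."""
    h = jax.nn.relu(node_features @ params["w_in"])           # (n, hidden_dim)
    for _ in range(num_steps):
        messages = adjacency @ h                              # sum over bonded neighbours
        h = jax.nn.relu((h + messages) @ params["w_msg"])     # update atom states
    h = h * node_mask[:, None]                                # zero out padded atoms
    graph_representation = h.sum(axis=0)                      # sum pooling -> (hidden_dim,)
    predictions = graph_representation @ params["w_out"]      # (num_tasks,)
    return predictions, graph_representation


key_in, key_msg, key_out = jax.random.split(jax.random.PRNGKey(0), 3)
in_features, hidden_dim, num_tasks = 34, 64, 1
params = {
    "w_in": 0.1 * jax.random.normal(key_in, (in_features, hidden_dim)),
    "w_msg": 0.1 * jax.random.normal(key_msg, (hidden_dim, hidden_dim)),
    "w_out": 0.1 * jax.random.normal(key_out, (hidden_dim, num_tasks)),
}

# Chain of 3 atoms (like the heavy atoms of CCO), padded to 5 rows; toy features
adjacency = jnp.zeros((5, 5)).at[0, 1].set(1).at[1, 0].set(1).at[1, 2].set(1).at[2, 1].set(1)
node_features = jnp.ones((5, in_features))
node_mask = jnp.array([1.0, 1.0, 1.0, 0.0, 0.0])
predictions, graph_rep = toy_predict(params, node_features, adjacency, node_mask)
print(predictions.shape, graph_rep.shape)  # (1,) (64,)
```

Because padded rows are disconnected and masked before pooling, padding a molecule does not change its prediction in this sketch, which is presumably the purpose of `node_mask` in the input specification.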
Fingerprint Computation
from diffbio.operators.drug_discovery import (
    DifferentiableMolecularFingerprint,
    MolecularFingerprintConfig,
    smiles_to_graph,
    DEFAULT_ATOM_FEATURES,
)
from flax import nnx
# Create fingerprint operator
config = MolecularFingerprintConfig(
    fingerprint_dim=256,
    hidden_dim=128,
    num_layers=3,
    in_features=DEFAULT_ATOM_FEATURES,
    normalize=True,
)
fp_op = DifferentiableMolecularFingerprint(config, rngs=nnx.Rngs(42))
# Compute fingerprint (smiles_to_graph returns a dictionary)
graph = smiles_to_graph("c1ccccc1")
data = {"node_features": graph["node_features"], "adjacency": graph["adjacency"]}
result, _, _ = fp_op.apply(data, {}, None)
fingerprint = result["fingerprint"]  # (256,)
Circular Fingerprint (ECFP/Morgan)
from diffbio.operators.drug_discovery import (
    CircularFingerprintOperator,
    CircularFingerprintConfig,
    create_ecfp4_operator,
    create_ecfp6_operator,
    create_fcfp4_operator,
    smiles_to_graph,
    DEFAULT_ATOM_FEATURES,
)
from flax import nnx
# Using a factory function for ECFP4-like fingerprints
ecfp4_op = create_ecfp4_operator(n_bits=2048)
# Or with the full configuration
config = CircularFingerprintConfig(
    radius=2,  # ECFP4 (diameter = 2 * radius = 4)
    n_bits=2048,  # Fingerprint dimension
    use_chirality=False,  # Set True to include stereochemistry
    use_features=False,  # Set True for FCFP (feature-based) instead of ECFP
    differentiable=True,  # Use learned hash functions
    hash_hidden_dim=128,  # Hidden dimension of the hash network
    temperature=1.0,  # Softmax temperature
    in_features=DEFAULT_ATOM_FEATURES,
)
fp_op = CircularFingerprintOperator(config, rngs=nnx.Rngs(42))
# Compute the fingerprint from a molecular graph (smiles_to_graph returns a dictionary)
graph = smiles_to_graph("c1ccccc1")
data = {
    "node_features": graph["node_features"],
    "adjacency": graph["adjacency"],
}
result, _, _ = fp_op.apply(data, {}, None)
fingerprint = result["fingerprint"]  # (2048,)
# Or compute directly from SMILES (RDKit mode)
config_rdkit = CircularFingerprintConfig(
    radius=2,
    n_bits=2048,
    differentiable=False,  # Use the exact RDKit fingerprint
)
fp_op_rdkit = CircularFingerprintOperator(config_rdkit)
data = {"smiles": "c1ccccc1"}
result, _, _ = fp_op_rdkit.apply(data, {}, None)
fingerprint_rdkit = result["fingerprint"]  # (2048,), binary
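The output specification describes soft bit probabilities in differentiable mode. One way a learned hash could produce them, assumed here purely for illustration (diffbio's internals may differ): grow atom environments for `radius` rounds of neighbour aggregation, map every atom embedding at every radius to a `temperature`-scaled softmax over `n_bits` buckets, and combine the assignments with a soft OR, 1 - ∏(1 - p). `soft_circular_fp` and its weight names are hypothetical.

```python
import jax
import jax.numpy as jnp


def soft_circular_fp(params, node_features, adjacency, node_mask, radius=2, temperature=1.0):
    """Hypothetical learned-hash circular fingerprint returning per-bit probabilities."""
    h = jnp.tanh(node_features @ params["w_embed"])      # radius-0 atom environments
    not_set = jnp.ones(params["w_hash"].shape[1])        # running product of (1 - p) per bit
    for step in range(radius + 1):
        if step > 0:
            # Each round widens every atom environment by one bond, like an ECFP iteration
            h = jnp.tanh((h + adjacency @ h) @ params["w_update"])
        logits = h @ params["w_hash"]                    # (n_atoms, n_bits)
        probs = jax.nn.softmax(logits / temperature, axis=-1) * node_mask[:, None]
        not_set = not_set * jnp.prod(1.0 - probs, axis=0)
    return 1.0 - not_set                                 # soft OR over atoms and radii


n_atoms, in_features, hidden_dim, n_bits = 6, 34, 128, 2048
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(42), 3)
params = {
    "w_embed": 0.1 * jax.random.normal(k1, (in_features, hidden_dim)),
    "w_update": 0.1 * jax.random.normal(k2, (hidden_dim, hidden_dim)),
    "w_hash": jax.random.normal(k3, (hidden_dim, n_bits)),
}
# Benzene-like ring: atom i bonded to atoms i - 1 and i + 1 (toy features)
ring = jnp.roll(jnp.eye(n_atoms), 1, axis=1) + jnp.roll(jnp.eye(n_atoms), -1, axis=1)
fp = soft_circular_fp(params, jnp.ones((n_atoms, in_features)), ring, jnp.ones(n_atoms))
print(fp.shape)  # (2048,)
```

In this construction, lowering `temperature` sharpens each softmax towards a single bucket, so the output approaches a binary fingerprint while staying differentiable at any finite temperature.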
Similarity Computation
from diffbio.operators.drug_discovery import (
MolecularSimilarityOperator,
MolecularSimilarityConfig,
tanimoto_similarity,
)
from flax import nnx
import jax.numpy as jnp
# Using operator
config = MolecularSimilarityConfig(similarity_type="tanimoto")
sim_op = MolecularSimilarityOperator(config, rngs=nnx.Rngs(42))
fp1 = jnp.array([1.0, 0.5, 0.0, 0.8])
fp2 = jnp.array([0.9, 0.6, 0.1, 0.7])
data = {"fingerprint_a": fp1, "fingerprint_b": fp2}
result, _, _ = sim_op.apply(data, {}, None)
similarity = result["similarity"]
# Using standalone function
sim = tanimoto_similarity(fp1, fp2)
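For intuition about what `tanimoto_similarity` measures on continuous fingerprints, here is a minimal sketch assuming the common continuous Tanimoto form a · b / (‖a‖² + ‖b‖² - a · b), with `eps` in the denominator as in the other similarity functions. `tanimoto_sketch` is a hypothetical name; the diffbio function may differ in detail.

```python
import jax.numpy as jnp


def tanimoto_sketch(a: jnp.ndarray, b: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    """Hypothetical continuous Tanimoto (Jaccard) similarity."""
    dot = jnp.dot(a, b)
    return dot / (jnp.dot(a, a) + jnp.dot(b, b) - dot + eps)


fp1 = jnp.array([1.0, 0.5, 0.0, 0.8])
fp2 = jnp.array([0.9, 0.6, 0.1, 0.7])
# a · b = 1.76, ‖a‖² = 1.89, ‖b‖² = 1.67, so the score is 1.76 / 1.80 ≈ 0.978
print(tanimoto_sketch(fp1, fp2))
```

On binary vectors the same expression equals |A ∩ B| / |A ∪ B|, the classic Tanimoto coefficient.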
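Batch Processing
from diffbio.operators.drug_discovery import batch_smiles_to_graphs
smiles_list = ["CCO", "CC(=O)O", "c1ccccc1"]
# batch_smiles_to_graphs returns a dictionary of padded arrays
batch = batch_smiles_to_graphs(smiles_list)
node_features = batch["node_features"]  # (3, max_nodes, 34) with the default atom features
adjacency = batch["adjacency"]  # (3, max_nodes, max_nodes)
edge_features = batch["edge_features"]  # (3, max_nodes, max_nodes, num_edge_features)
node_mask = batch["node_mask"]  # (3, max_nodes)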
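Because every molecule in the batch shares the same padded shape, a per-molecule function can be mapped over the leading batch axis with `jax.vmap`. The sketch below uses a hypothetical masked sum-pooling readout on random toy arrays with the documented key layout; it does not call diffbio, and the valid atom counts assume hydrogens are implicit (heavy atoms only).

```python
import jax
import jax.numpy as jnp


def masked_sum_readout(node_features: jnp.ndarray, node_mask: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical per-molecule readout: sum the features of valid atoms only."""
    return (node_features * node_mask[:, None]).sum(axis=0)


batch_size, max_nodes, num_features = 3, 6, 34
node_features = jax.random.normal(jax.random.PRNGKey(0), (batch_size, max_nodes, num_features))
# Valid atom counts 3, 4 and 6, mirroring the heavy atoms of CCO, CC(=O)O and c1ccccc1
node_mask = (jnp.arange(max_nodes)[None, :] < jnp.array([3, 4, 6])[:, None]).astype(jnp.float32)

# vmap maps over axis 0 of both arguments: one readout per molecule
readouts = jax.vmap(masked_sum_readout)(node_features, node_mask)
print(readouts.shape)  # (3, 34)
```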
Gradient Computation
from flax import nnx
from diffbio.operators.drug_discovery import (
    create_property_predictor,
    smiles_to_graph,
)
predictor = create_property_predictor(
    hidden_dim=32,
    num_layers=2,
    num_tasks=1,
)
graph = smiles_to_graph("CCO")
data = {
    "node_features": graph["node_features"],
    "adjacency": graph["adjacency"],
    "edge_features": graph["edge_features"],
}
def loss_fn(model, data):
    result, _, _ = model.apply(data, {}, None)
    return result["predictions"].sum()
# Compute gradients with respect to the model parameters with nnx.grad
grads = nnx.grad(loss_fn)(predictor, data)
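Gradients can also be taken with respect to inputs rather than parameters. A minimal plain-JAX sketch, independent of diffbio: differentiate a cosine similarity score with respect to a query fingerprint using `jax.grad`. The `cosine` function is a local stand-in written for this example (assuming the standard formula with `eps` in the denominator), not the library's `cosine_similarity`.

```python
import jax
import jax.numpy as jnp


def cosine(a: jnp.ndarray, b: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    # Local stand-in for a similarity score (not the diffbio function)
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + eps)


fp_query = jnp.array([1.0, 0.5, 0.0, 0.8])
fp_target = jnp.array([0.9, 0.6, 0.1, 0.7])

# d(similarity)/d(fp_query): the direction that moves the query towards the target
grad_query = jax.grad(cosine, argnums=0)(fp_query, fp_target)
print(grad_query.shape)  # (4,)

# Cosine similarity ignores the scale of fp_query, so its gradient is
# (numerically) orthogonal to fp_query
print(jnp.dot(grad_query, fp_query))  # ≈ 0
```

The orthogonality check is a handy sanity test: any sizeable gradient component along the query direction would point to a bug in a scale-invariant score.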
Input Specifications
MolecularPropertyPredictor
| Key | Shape | Type | Description |
|---|---|---|---|
| node_features | (n, in_features) | float32 | Atom feature vectors |
| adjacency | (n, n) | float32 | Adjacency matrix |
| edge_features | (n, n, num_edge_features) | float32 | Optional bond features |
| node_mask | (n,) | float32 | Optional mask for valid atoms |
DifferentiableMolecularFingerprint
| Key | Shape | Type | Description |
|---|---|---|---|
| node_features | (n, in_features) | float32 | Atom feature vectors |
| adjacency | (n, n) | float32 | Adjacency matrix |
| edge_features | (n, n, num_edge_features) | float32 | Optional bond features |
| node_mask | (n,) | float32 | Optional mask for valid atoms |
CircularFingerprintOperator
Differentiable mode (differentiable=True):
| Key | Shape | Type | Description |
|---|---|---|---|
| node_features | (n, in_features) | float32 | Atom feature vectors |
| adjacency | (n, n) | float32 | Adjacency matrix |
| node_mask | (n,) | float32 | Optional mask for valid atoms |
RDKit mode (differentiable=False):
| Key | Type | Description |
|---|---|---|
| smiles | str | SMILES string |
MolecularSimilarityOperator
| Key | Shape | Type | Description |
|---|---|---|---|
| fingerprint_a | (dim,) | float32 | First fingerprint vector |
| fingerprint_b | (dim,) | float32 | Second fingerprint vector |
Output Specifications
MolecularPropertyPredictor
| Key | Shape | Type | Description |
|---|---|---|---|
| predictions | (num_tasks,) | float32 | Property predictions |
| graph_representation | (hidden_dim,) | float32 | Graph-level embedding |
DifferentiableMolecularFingerprint
| Key | Shape | Type | Description |
|---|---|---|---|
| fingerprint | (fingerprint_dim,) | float32 | Molecular fingerprint |
CircularFingerprintOperator
| Key | Shape | Type | Description |
|---|---|---|---|
| fingerprint | (n_bits,) | float32 | Circular fingerprint (soft probabilities in differentiable mode, binary in RDKit mode) |
MolecularSimilarityOperator
| Key | Shape | Type | Description |
|---|---|---|---|
| similarity | () | float32 | Similarity score |
Atom Feature Dimensions
| Feature | Dimensions | Description |
|---|---|---|
| Atom type | 12 | C, N, O, S, F, Cl, Br, I, P, Si, B, other |
| Degree | 7 | 0-6 neighbors |
| Formal charge | 5 | -2 to +2 |
| Hybridization | 4 | SP, SP2, SP3, SP3D |
| Aromaticity | 1 | Binary |
| Hydrogens | 5 | 0-4 |
| Total | 34 | DEFAULT_ATOM_FEATURES |