Drug Discovery Operators API

Differentiable operators for molecular property prediction, fingerprint computation, and similarity scoring.

MolecularPropertyPredictor

diffbio.operators.drug_discovery.property_predictor.MolecularPropertyPredictor

MolecularPropertyPredictor(
    config: MolecularPropertyConfig,
    *,
    rngs: Rngs | None = None,
)

Bases: OperatorModule

ChemProp-style molecular property predictor.

Implements a directed message passing neural network (D-MPNN) for predicting molecular properties from graph representations.

The architecture consists of:
  1. Message passing layers to compute atom representations
  2. Graph-level readout via sum pooling
  3. Feed-forward network for property prediction
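The three stages can be sketched in plain NumPy. This is a simplified node-level variant, not DiffBio's implementation. The weight matrices, the update `h = relu(h0 + (A @ h) @ W_msg)`, and the single linear output layer are illustrative assumptions, and the edge-centred D-MPNN update is not reproduced here.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def predict_properties(
    node_features: np.ndarray,  # (num_nodes, in_features)
    adjacency: np.ndarray,  # (num_nodes, num_nodes), symmetric 0/1
    node_mask: np.ndarray,  # (num_nodes,), 1.0 for real atoms, 0.0 for padding
    w_in: np.ndarray,  # (in_features, hidden_dim), assumed input projection
    w_msg: np.ndarray,  # (hidden_dim, hidden_dim), assumed message weights
    w_out: np.ndarray,  # (hidden_dim, num_tasks), assumed output head
    num_steps: int = 3,
) -> np.ndarray:
    mask = node_mask[:, None]
    # Initial atom states; padded rows are zeroed so they never send messages.
    h0 = relu(node_features @ w_in) * mask
    h = h0
    # 1. Message passing: each atom sums its neighbours' current states.
    for _ in range(num_steps):
        h = relu(h0 + (adjacency @ h) @ w_msg) * mask
    # 2. Graph-level readout via sum pooling over real atoms.
    graph_vector = h.sum(axis=0)  # (hidden_dim,)
    # 3. Feed-forward head, reduced here to one linear layer.
    return graph_vector @ w_out  # (num_tasks,)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 4))
adj = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 0]], dtype=float)
mask = np.array([1.0, 1.0, 1.0, 0.0])  # the fourth row is padding
w_in, w_msg, w_out = rng.normal(size=(4, 8)), 0.1 * rng.normal(size=(8, 8)), rng.normal(size=(8, 3))
preds = predict_properties(x, adj, mask, w_in, w_msg, w_out)
print(preds.shape)  # (3,)
```

Sum pooling makes the output independent of atom ordering and of whatever values sit in padded rows. That is what lets padded batches share one computation.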

Example
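import jax.numpy as jnp
from flax import nnx
from diffbio.operators.drug_discovery.property_predictor import (
    MolecularPropertyConfig, MolecularPropertyPredictor)
node_features = jnp.ones((3, 4))  # toy 3-atom graph; width matches default in_features=4
adjacency = jnp.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=jnp.float32)
mask = jnp.ones((3,))
config = MolecularPropertyConfig(hidden_dim=64, num_output_tasks=3)
predictor = MolecularPropertyPredictor(config, rngs=nnx.Rngs(42))
data = {
    "node_features": node_features,
    "adjacency": adjacency,
    "node_mask": mask,
}
result, state, meta = predictor.apply(data, {}, None)
predictions = result["predictions"]  # shape: (3,)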

Parameters:

Name Type Description Default
config MolecularPropertyConfig

Predictor configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Predict molecular properties from graph representation.

Parameters:

Name Type Description Default
data dict[str, Any]

Input data containing:
- node_features: (num_nodes, num_features) atom features
- adjacency: (num_nodes, num_nodes) adjacency matrix
- edge_features: optional, (num_nodes, num_nodes, num_edge_features)
- node_mask: (num_nodes,) mask for valid nodes

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata.

required
random_params Any

Unused random parameters.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of:
- data with an added "predictions" key
- unchanged state
- unchanged metadata

MolecularPropertyConfig

diffbio.operators.drug_discovery.property_predictor.MolecularPropertyConfig dataclass

MolecularPropertyConfig(
    hidden_dim: int = 300,
    num_message_passing_steps: int = 3,
    num_output_tasks: int = 1,
    dropout_rate: float = 0.0,
    in_features: int = 4,
    num_edge_features: int = 4,
)

Bases: OperatorConfig

Configuration for molecular property predictor.

Attributes:

Name Type Description
hidden_dim int

Hidden dimension for message passing layers.

num_message_passing_steps int

Number of message passing iterations.

num_output_tasks int

Number of prediction tasks (multi-task learning).

dropout_rate float

Dropout rate for regularization.

in_features int

Number of input node features (default: 4; set to DEFAULT_ATOM_FEATURES=34 to match the default atom featurization).

num_edge_features int

Number of edge/bond features.

ADMETPredictor

diffbio.operators.drug_discovery.admet_predictor.ADMETPredictor

ADMETPredictor(
    config: ADMETConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Multi-task ADMET property predictor.

Implements a ChemProp-style directed message passing neural network for predicting multiple ADMET properties simultaneously. The architecture uses a shared molecular encoder with task-specific prediction heads.

Architecture
  1. Message passing encoder (D-MPNN style)
  2. Graph-level readout via sum pooling
  3. Shared feed-forward layers
  4. Task-specific output heads
The 22 standard TDC ADMET endpoints cover:
  • Absorption: Caco2, HIA, Pgp, Bioavailability, Lipophilicity, Solubility
  • Distribution: BBB, PPBR, VDss
  • Metabolism: CYP enzymes (2C9, 2D6, 3A4) inhibition and substrate
  • Excretion: Half-life, Hepatocyte clearance, Microsome clearance
  • Toxicity: LD50, hERG, AMES, DILI
Example
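import jax.numpy as jnp
from flax import nnx
from diffbio.operators.drug_discovery.admet_predictor import ADMETConfig, ADMETPredictor
nodes, mask = jnp.ones((3, 4)), jnp.ones((3,))  # toy 3-atom graph, default in_features=4
adj = jnp.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=jnp.float32)
config = ADMETConfig(hidden_dim=256, num_tasks=22)
predictor = ADMETPredictor(config, rngs=nnx.Rngs(42))
data = {"node_features": nodes, "adjacency": adj, "node_mask": mask}
result, _, _ = predictor.apply(data, {}, None)
predictions = result["predictions"]  # shape: (22,)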
References
  • https://tdcommons.ai/benchmark/admet_group/overview/
  • Yang et al. "Analyzing Learned Molecular Representations for Property Prediction" JCIM 2019

Parameters:

Name Type Description Default
config ADMETConfig

ADMET configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional name for the operator.

None

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Predict ADMET properties from molecular graph.

Parameters:

Name Type Description Default
data dict[str, Any]

Input data containing:
- node_features: (num_nodes, num_features) atom features
- adjacency: (num_nodes, num_nodes) adjacency matrix
- edge_features: optional, (num_nodes, num_nodes, num_edge_features)
- node_mask: (num_nodes,) mask for valid nodes

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata.

required
random_params Any

Unused random parameters.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of:
- data with added "predictions" and "task_predictions" keys
- unchanged state
- unchanged metadata

ADMETConfig

diffbio.operators.drug_discovery.admet_predictor.ADMETConfig dataclass

ADMETConfig(
    hidden_dim: int = 300,
    num_message_passing_steps: int = 3,
    num_tasks: int = 22,
    dropout_rate: float = 0.0,
    in_features: int = 4,
    num_edge_features: int = 4,
    ffn_hidden_dim: int | None = None,
    ffn_num_layers: int = 2,
    apply_task_activations: bool = False,
)

Bases: OperatorConfig

Configuration for ADMET property predictor.

Attributes:

Name Type Description
hidden_dim int

Hidden dimension for message passing (default: 300).

num_message_passing_steps int

Number of D-MPNN iterations (default: 3).

num_tasks int

Number of ADMET prediction tasks (default: 22).

dropout_rate float

Dropout rate for regularization (default: 0.0).

in_features int

Number of input node features (default: 4).

num_edge_features int

Number of edge features (default: 4).

ffn_hidden_dim int | None

FFN hidden dimension (default: None, which falls back to hidden_dim).

ffn_num_layers int

Number of FFN layers (default: 2).

apply_task_activations bool

Apply sigmoid for classification tasks (default: False).
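The following NumPy sketch shows how task-specific heads and apply_task_activations could interact. Each task has its own linear head on the shared embedding. When activations are enabled, a sigmoid is applied only to classification endpoints. The `is_classification` mask and the choice of which tasks are classification are hypothetical; the source states only that a sigmoid is applied for classification tasks.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def admet_heads(
    shared: np.ndarray,  # (ffn_hidden_dim,) molecule embedding from the shared trunk
    head_weights: np.ndarray,  # (num_tasks, ffn_hidden_dim): one linear head per task
    head_biases: np.ndarray,  # (num_tasks,)
    is_classification: np.ndarray,  # (num_tasks,) bool, hypothetical task layout
    apply_task_activations: bool = False,
) -> np.ndarray:
    logits = head_weights @ shared + head_biases  # (num_tasks,)
    if apply_task_activations:
        # Classification endpoints become probabilities; regression outputs stay raw.
        return np.where(is_classification, sigmoid(logits), logits)
    return logits


shared = np.array([0.5, -1.0, 2.0])
head_weights = np.eye(3)  # toy heads that each read one embedding dimension
head_biases = np.zeros(3)
is_classification = np.array([True, False, True])
raw = admet_heads(shared, head_weights, head_biases, is_classification)
activated = admet_heads(shared, head_weights, head_biases, is_classification, apply_task_activations=True)
```

Keeping the default at False leaves raw logits, which pair naturally with logit-based losses such as binary cross-entropy on logits.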

MACCSKeysOperator

diffbio.operators.drug_discovery.maccs_keys.MACCSKeysOperator

MACCSKeysOperator(
    config: MACCSKeysConfig, *, rngs: Rngs | None = None
)

Bases: OperatorModule

Differentiable MACCS structural keys fingerprint operator.

For differentiable=True: Uses message passing and learned pattern detectors to approximate MACCS key detection. Each of the 166 keys is represented by a learned pattern matching network that outputs a soft presence score.

For differentiable=False: Uses RDKit's exact MACCS implementation (not differentiable).

The differentiable version enables gradient flow for end-to-end optimization while approximating the structural pattern detection of traditional MACCS keys.

MACCS keys encode various structural features:

- Atom types (C, N, O, S, halogens, etc.)
- Functional groups (carbonyl, hydroxyl, amine, etc.)
- Ring systems (aromatic, aliphatic)
- Bond patterns and connectivity
Example
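import jax.numpy as jnp
from flax import nnx
from diffbio.operators.drug_discovery.maccs_keys import MACCSKeysConfig, MACCSKeysOperator
nodes = jnp.ones((3, 4))  # toy 3-atom graph, default in_features=4
adj = jnp.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=jnp.float32)
config = MACCSKeysConfig(temperature=1.0)
op = MACCSKeysOperator(config, rngs=nnx.Rngs(42))
data = {"node_features": nodes, "adjacency": adj}
result, _, _ = op.apply(data, {}, None)
fingerprint = result["fingerprint"]  # shape: (166,)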
References
  • Durant et al. "Reoptimization of MDL Keys for Use in Drug Discovery" J. Chem. Inf. Comput. Sci. 2002

Parameters:

Name Type Description Default
config MACCSKeysConfig

MACCS keys configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Compute MACCS keys fingerprint.

Parameters:

Name Type Description Default
data dict[str, Any]

Input data containing:
For differentiable=True:
- node_features: (num_nodes, num_features) atom features
- adjacency: (num_nodes, num_nodes) adjacency matrix
- node_mask: (num_nodes,) optional mask for valid nodes
For differentiable=False:
- smiles: SMILES string

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata.

required
random_params Any

Unused random parameters.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of:
- data with an added "fingerprint" key
- unchanged state
- unchanged metadata

MACCSKeysConfig

diffbio.operators.drug_discovery.maccs_keys.MACCSKeysConfig dataclass

MACCSKeysConfig(
    n_bits: int = 166,
    differentiable: bool = True,
    temperature: float = 1.0,
    hidden_dim: int = 64,
    num_layers: int = 2,
    in_features: int = 4,
)

Bases: OperatorConfig

Configuration for MACCS keys fingerprint operator.

Attributes:

Name Type Description
n_bits int

Number of fingerprint bits (default: 166 for standard MACCS).

differentiable bool

Use learned pattern matching (default: True).

temperature float

Temperature for soft bit assignment (default: 1.0).

hidden_dim int

Hidden dimension for pattern networks (default: 64).

num_layers int

Number of message passing layers (default: 2).

in_features int

Number of input node features (default: 4).
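Soft key detection can be sketched as a temperature-scaled sigmoid on per-atom match scores, followed by a differentiable OR over atoms. The `atom_logits` input stands in for the output of the learned pattern networks. Both the division by temperature and the noisy-OR aggregation are assumptions, not MACCSKeysOperator's documented internals.

```python
import numpy as np


def soft_key_presence(
    atom_logits: np.ndarray,  # (num_nodes, n_bits) per-atom match scores (assumed input)
    node_mask: np.ndarray,  # (num_nodes,)
    temperature: float = 1.0,
) -> np.ndarray:
    # Temperature-scaled sigmoid: per-atom probability that a key's pattern matches here.
    p = 1.0 / (1.0 + np.exp(-atom_logits / temperature))
    # Padded atoms must not switch any key on.
    p = p * node_mask[:, None]
    # Differentiable OR over atoms: the key is present if any atom matches.
    return 1.0 - np.prod(1.0 - p, axis=0)  # (n_bits,)


logits = np.array([[2.0, -2.0], [-3.0, -1.0], [5.0, 5.0]])
mask = np.array([1.0, 1.0, 0.0])  # third atom is padding
soft = soft_key_presence(logits, mask, temperature=1.0)
sharp = soft_key_presence(logits, mask, temperature=0.01)
```

Lowering the temperature pushes each key toward a hard 0/1 decision. Higher temperatures keep gradients informative during training.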

AttentiveFP

diffbio.operators.drug_discovery.attentive_fp.AttentiveFP

AttentiveFP(
    config: AttentiveFPConfig, *, rngs: Rngs | None = None
)

Bases: OperatorModule

AttentiveFP: Attention-based molecular fingerprint.

Implements the AttentiveFP architecture with
  1. Atom-level attention layers with GRU refinement
  2. Molecule-level aggregation with attention and GRU
  3. Final projection to fingerprint dimension

The model provides interpretable attention weights that indicate which atoms contribute most to the molecular representation.
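The molecule-level attention step can be sketched as a masked softmax over atom scores. The scoring rule is an assumption: a LeakyReLU of a dot product with a molecule-level query, using negative_slope=0.2 from AttentiveFPConfig. The GRU update that AttentiveFP applies to the resulting context vector is omitted.

```python
import numpy as np


def leaky_relu(x: np.ndarray, negative_slope: float = 0.2) -> np.ndarray:
    return np.where(x > 0, x, negative_slope * x)


def attention_readout(
    atom_states: np.ndarray,  # (num_nodes, d)
    query: np.ndarray,  # (d,) molecule-level state (assumed scoring vector)
    node_mask: np.ndarray,  # (num_nodes,)
    negative_slope: float = 0.2,
) -> tuple[np.ndarray, np.ndarray]:
    scores = leaky_relu(atom_states @ query, negative_slope)
    # Padded atoms get -inf so their softmax weight is exactly zero.
    scores = np.where(node_mask > 0, scores, -np.inf)
    scores = scores - scores.max()  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum()  # softmax over real atoms
    context = weights @ atom_states  # (d,) attention-weighted sum of atom states
    return context, weights


atom_states = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [9.0, 9.0]])
query = np.array([1.0, 0.5])
mask = np.array([1.0, 1.0, 1.0, 0.0])  # the large fourth row is padding and is ignored
context, weights = attention_readout(atom_states, query, mask)
```

The returned `weights` play the role of the interpretable per-atom attention scores described above.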

Example
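import jax.numpy as jnp
from flax import nnx
from diffbio.operators.drug_discovery.attentive_fp import AttentiveFP, AttentiveFPConfig
nodes = jnp.ones((3, 39))  # toy 3-atom graph, default in_features=39
adj = jnp.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=jnp.float32)
edges = jnp.ones((3, 3, 10))  # default edge_dim=10
config = AttentiveFPConfig(hidden_dim=128, out_dim=256)
afp = AttentiveFP(config, rngs=nnx.Rngs(42))
data = {"node_features": nodes, "adjacency": adj, "edge_features": edges}
result, _, _ = afp.apply(data, {}, None)
fingerprint = result["fingerprint"]  # (256,)
attn = result["attention_weights"]  # interpretability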
References
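  • Xiong et al. "Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism" J. Med. Chem. 2020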

Parameters:

Name Type Description Default
config AttentiveFPConfig

AttentiveFP configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Compute AttentiveFP molecular fingerprint.

Parameters:

Name Type Description Default
data dict[str, Any]

Input data containing:
- node_features: (num_nodes, in_features) atom features
- adjacency: (num_nodes, num_nodes) adjacency matrix
- edge_features: optional, (num_nodes, num_nodes, edge_dim)
- node_mask: (num_nodes,) optional mask for valid nodes

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata.

required
random_params Any

Unused random parameters.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of:
- data with "fingerprint" and "attention_weights" keys
- unchanged state
- unchanged metadata

AttentiveFPConfig

diffbio.operators.drug_discovery.attentive_fp.AttentiveFPConfig dataclass

AttentiveFPConfig(
    in_features: int = 39,
    edge_dim: int = 10,
    negative_slope: float = 0.2,
    hidden_dim: int = 200,
    out_dim: int = 200,
    num_layers: int = 2,
    num_timesteps: int = 2,
    dropout_rate: float = 0.0,
)

Bases: _AttentiveFPArchitectureConfig, _AttentiveFPInputConfig, OperatorConfig

Configuration for AttentiveFP operator.

Attributes:

Name Type Description
in_features int

Number of input atom features (default: 39).

edge_dim int

Number of edge/bond features (default: 10).

negative_slope float

Negative slope of the LeakyReLU activation (default: 0.2).

hidden_dim int

Hidden dimension for GNN layers (default: 200).

out_dim int

Output fingerprint dimension (default: 200).

num_layers int

Number of atom-level attention layers (default: 2).

num_timesteps int

Number of molecule-level GRU iterations (default: 2).

dropout_rate float

Dropout rate for regularization (default: 0.0).

DifferentiableMolecularFingerprint

diffbio.operators.drug_discovery.fingerprint.DifferentiableMolecularFingerprint

DifferentiableMolecularFingerprint(
    config: MolecularFingerprintConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Neural graph fingerprint operator.

Computes learned molecular fingerprints using graph neural networks. Unlike traditional fingerprints (e.g., ECFP/Morgan), these are fully differentiable and can be optimized for specific tasks.

The fingerprint is computed by:
  1. Message passing to compute atom representations
  2. Sum pooling to get a graph-level representation
  3. Linear projection to the fingerprint dimension
  4. Optional L2 normalization
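Steps 2 to 4 can be sketched in NumPy, taking the atom representations from step 1 as given. The `w_proj` matrix and the small `eps` guard are illustrative assumptions. The example also checks a useful consequence of normalize=True: the dot product of two normalized fingerprints equals the cosine similarity of the raw ones.

```python
import numpy as np


def fingerprint_readout(
    atom_states: np.ndarray,  # (num_nodes, hidden_dim), output of message passing
    node_mask: np.ndarray,  # (num_nodes,)
    w_proj: np.ndarray,  # (hidden_dim, fingerprint_dim), assumed projection
    normalize: bool = False,
    eps: float = 1e-12,
) -> np.ndarray:
    pooled = (atom_states * node_mask[:, None]).sum(axis=0)  # 2. sum pooling over real atoms
    fp = pooled @ w_proj  # 3. linear projection
    if normalize:
        fp = fp / (np.linalg.norm(fp) + eps)  # 4. optional L2 normalization
    return fp


rng = np.random.default_rng(0)
w_proj = rng.normal(size=(8, 16))
mask = np.ones(5)
h_a, h_b = rng.normal(size=(5, 8)), rng.normal(size=(5, 8))
raw_a = fingerprint_readout(h_a, mask, w_proj)
raw_b = fingerprint_readout(h_b, mask, w_proj)
unit_a = fingerprint_readout(h_a, mask, w_proj, normalize=True)
unit_b = fingerprint_readout(h_b, mask, w_proj, normalize=True)
```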

Example
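import jax.numpy as jnp
from flax import nnx
from diffbio.operators.drug_discovery.fingerprint import DifferentiableMolecularFingerprint, MolecularFingerprintConfig
nodes, mask = jnp.ones((3, 4)), jnp.ones((3,))  # toy 3-atom graph, default in_features=4
adj = jnp.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=jnp.float32)
config = MolecularFingerprintConfig(fingerprint_dim=128)
fp_op = DifferentiableMolecularFingerprint(config, rngs=nnx.Rngs(42))
data = {"node_features": nodes, "adjacency": adj, "node_mask": mask}
result, _, _ = fp_op.apply(data, {}, None)
fingerprint = result["fingerprint"]  # shape: (128,)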

Parameters:

Name Type Description Default
config MolecularFingerprintConfig

Fingerprint configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None
name str | None

Optional name for the operator.

None

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Compute molecular fingerprint.

Parameters:

Name Type Description Default
data dict[str, Any]

Input data containing:
- node_features: (num_nodes, num_features) atom features
- adjacency: (num_nodes, num_nodes) adjacency matrix
- node_mask: (num_nodes,) mask for valid nodes

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata.

required
random_params Any

Unused random parameters.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of:
- data with an added "fingerprint" key
- unchanged state
- unchanged metadata

MolecularFingerprintConfig

diffbio.operators.drug_discovery.fingerprint.MolecularFingerprintConfig dataclass

MolecularFingerprintConfig(
    fingerprint_dim: int = 256,
    hidden_dim: int = 128,
    num_layers: int = 3,
    in_features: int = 4,
    normalize: bool = False,
)

Bases: OperatorConfig

Configuration for molecular fingerprint operator.

Attributes:

Name Type Description
fingerprint_dim int

Dimension of output fingerprint vector.

hidden_dim int

Hidden dimension for graph convolutions.

num_layers int

Number of graph convolution layers.

in_features int

Number of input node features (default: 4; set to DEFAULT_ATOM_FEATURES=34 to match the default atom featurization).

normalize bool

Whether to L2-normalize the fingerprint.

CircularFingerprintOperator

diffbio.operators.drug_discovery.fingerprint.CircularFingerprintOperator

CircularFingerprintOperator(
    config: CircularFingerprintConfig,
    *,
    rngs: Rngs | None = None,
)

Bases: OperatorModule

Differentiable circular fingerprints (ECFP/Morgan).

For differentiable=True: Uses message passing to aggregate substructure information, then learned "soft hash" functions for bit assignment. Gradients flow through the entire computation.

For differentiable=False: Wraps the RDKit implementation to compute exact ECFP. Gradients do not flow through this path, so it is intended for inference or comparison.

The differentiable version approximates ECFP behavior while enabling end-to-end optimization of the fingerprint representation.

Example
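import jax.numpy as jnp
from flax import nnx
from diffbio.operators.drug_discovery.fingerprint import CircularFingerprintConfig, CircularFingerprintOperator
node_feats = jnp.ones((3, 4))  # toy 3-atom graph, default in_features=4
adj = jnp.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=jnp.float32)
config = CircularFingerprintConfig(radius=2, n_bits=1024)
fp_op = CircularFingerprintOperator(config, rngs=nnx.Rngs(0))
data = {"node_features": node_feats, "adjacency": adj}
result, state, meta = fp_op.apply(data, {}, None)
fingerprint = result["fingerprint"]  # shape: (1024,)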
References
  • Rogers and Hahn "Extended-Connectivity Fingerprints" JCIM 2010

Parameters:

Name Type Description Default
config CircularFingerprintConfig

Circular fingerprint configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Compute circular fingerprint.

Parameters:

Name Type Description Default
data dict[str, Any]

Input data containing either:
For differentiable=True:
- node_features: (num_nodes, num_features) atom features
- adjacency: (num_nodes, num_nodes) adjacency matrix
- node_mask: (num_nodes,) optional mask for valid nodes
For differentiable=False:
- smiles: SMILES string

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata.

required
random_params Any

Unused random parameters.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of:
- data with an added "fingerprint" key
- unchanged state
- unchanged metadata

CircularFingerprintConfig

diffbio.operators.drug_discovery.fingerprint.CircularFingerprintConfig dataclass

CircularFingerprintConfig(
    radius: int = 2,
    n_bits: int = 2048,
    use_chirality: bool = False,
    use_bond_types: bool = True,
    use_features: bool = False,
    differentiable: bool = True,
    hash_hidden_dim: int = 128,
    temperature: float = 1.0,
    in_features: int = 4,
)

Bases: OperatorConfig

Configuration for circular fingerprint operator (ECFP/Morgan).

Attributes:

Name Type Description
radius int

Fingerprint radius. ECFP4 = radius 2, ECFP6 = radius 3.

n_bits int

Number of bits in fingerprint (default: 2048).

use_chirality bool

Include chirality in fingerprint (default: False).

use_bond_types bool

Include bond type information (default: True).

use_features bool

Use pharmacophoric features (FCFP variant, default: False).

differentiable bool

Use learned hash functions for gradients (default: True).

hash_hidden_dim int

Hidden dimension for hash network (default: 128).

temperature float

Temperature for soft bit assignment (default: 1.0).

in_features int

Number of input node features (default: 4).
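One way to make bit assignment "soft" is a temperature-scaled softmax over the n_bits positions for each atom environment, followed by a differentiable OR over atoms. This is a sketch under assumptions. `hash_weights` stands in for the learned hash network (hash_hidden_dim is not modelled), and CircularFingerprintOperator may combine the assignments differently. As the temperature falls, each atom's assignment approaches a one-hot hard hash.

```python
import numpy as np


def soft_hash_fingerprint(
    env_embeddings: np.ndarray,  # (num_nodes, d): one vector per atom environment
    hash_weights: np.ndarray,  # (d, n_bits): assumed stand-in for a learned hash
    node_mask: np.ndarray,  # (num_nodes,)
    temperature: float = 1.0,
) -> np.ndarray:
    logits = env_embeddings @ hash_weights / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    assign = np.exp(logits)
    assign = assign / assign.sum(axis=1, keepdims=True)  # softmax: each atom spreads one unit over bits
    assign = assign * node_mask[:, None]  # padded atoms set no bits
    # Differentiable OR: a bit is on if any atom environment hashes to it.
    return 1.0 - np.prod(1.0 - assign, axis=0)  # (n_bits,)


env = np.eye(3)  # three toy atom environments
w = np.zeros((3, 8))
w[0, 2] = w[1, 5] = w[2, 7] = 1.0  # environments prefer bits 2, 5 and 7
mask = np.array([1.0, 1.0, 0.0])  # third atom is padding
fp_soft = soft_hash_fingerprint(env, w, mask, temperature=1.0)
fp_hard = soft_hash_fingerprint(env, w, mask, temperature=0.01)
```

At temperature 0.01 only bits 2 and 5 remain set, which mirrors a hard ECFP hash. At temperature 1.0 every bit receives some mass, so gradients reach every hash output.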

MolecularSimilarityOperator

diffbio.operators.drug_discovery.similarity.MolecularSimilarityOperator

MolecularSimilarityOperator(
    config: MolecularSimilarityConfig,
    *,
    rngs: Rngs | None = None,
)

Bases: OperatorModule

Differentiable molecular similarity operator.

Computes similarity between molecular fingerprints using various differentiable metrics. Supports Tanimoto, cosine, and Dice similarity.

Example
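import jax.numpy as jnp
from flax import nnx
from diffbio.operators.drug_discovery.similarity import MolecularSimilarityConfig, MolecularSimilarityOperator
fp1 = jnp.array([1.0, 1.0, 0.0, 1.0])  # toy non-negative fingerprints
fp2 = jnp.array([1.0, 0.0, 0.0, 1.0])
config = MolecularSimilarityConfig(similarity_type="tanimoto")
sim_op = MolecularSimilarityOperator(config, rngs=nnx.Rngs(42))
data = {"fingerprint_a": fp1, "fingerprint_b": fp2}
result, _, _ = sim_op.apply(data, {}, None)
similarity = result["similarity"]  # scalar in [0, 1] for non-negative fingerprints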

Parameters:

Name Type Description Default
config MolecularSimilarityConfig

Similarity configuration.

required
rngs Rngs | None

Flax NNX random number generators.

None

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Compute similarity between two fingerprints.

Parameters:

Name Type Description Default
data dict[str, Any]

Input data containing:
- fingerprint_a: First fingerprint vector
- fingerprint_b: Second fingerprint vector

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata.

required
random_params Any

Unused random parameters.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of:
- data with an added "similarity" key
- unchanged state
- unchanged metadata

MolecularSimilarityConfig

diffbio.operators.drug_discovery.similarity.MolecularSimilarityConfig dataclass

MolecularSimilarityConfig(
    similarity_type: str = "tanimoto",
    temperature: float = 1.0,
)

Bases: OperatorConfig

Configuration for molecular similarity operator.

Attributes:

Name Type Description
similarity_type str

Type of similarity metric ("tanimoto", "cosine", "dice").

temperature float

Temperature for soft similarity (higher = sharper).

Message Passing Layers

MessagePassingLayer

diffbio.operators.drug_discovery.message_passing.MessagePassingLayer

MessagePassingLayer(
    hidden_dim: int,
    in_features: int = 4,
    num_edge_features: int = 4,
    *,
    rngs: Rngs,
)

Bases: Module

Directed message passing layer for molecular graphs.

Implements the D-MPNN message passing scheme where messages are passed along directed edges. Each node aggregates messages from its neighbors and updates its representation.
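For reference, the textbook D-MPNN update from Yang et al. (2019) keeps a hidden state on each directed edge. When it updates edge v→w, it sums the messages arriving at v from every neighbour except w, so a message cannot flow straight back along the edge it came from. The NumPy sketch below implements that rule with an explicit edge list. MessagePassingLayer instead takes node features and a dense adjacency matrix and returns node features, so its internals may differ from this edge-centred form. The simplified `node_readout` (a plain sum without the atom-feature projection) is also an assumption.

```python
import numpy as np


def dmpnn_edge_update(
    h_edges: np.ndarray,  # (num_directed_edges, d): current state of each directed edge
    h0_edges: np.ndarray,  # (num_directed_edges, d): initial edge states
    edges: list[tuple[int, int]],  # directed (src, dst) pairs; each bond appears in both directions
    w_m: np.ndarray,  # (d, d) message weight matrix
) -> np.ndarray:
    new_h = np.zeros_like(h_edges)
    for i, (v, w) in enumerate(edges):
        # Sum messages arriving at v from every neighbour k except w (no backtracking).
        m = np.zeros(h_edges.shape[1])
        for j, (k, dst) in enumerate(edges):
            if dst == v and k != w:
                m += h_edges[j]
        new_h[i] = np.maximum(h0_edges[i] + m @ w_m, 0.0)  # ReLU(h0_vw + W_m m_vw)
    return new_h


def node_readout(h_edges: np.ndarray, edges: list[tuple[int, int]], num_nodes: int) -> np.ndarray:
    # Simplified atom state: sum of the states of all edges pointing into the atom.
    out = np.zeros((num_nodes, h_edges.shape[1]))
    for j, (_, dst) in enumerate(edges):
        out[dst] += h_edges[j]
    return out


# Three-atom chain 0-1-2 with directed edges 0->1, 1->0, 1->2, 2->1.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
h = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0], [0.0, 4.0]])
h_next = dmpnn_edge_update(h, np.zeros_like(h), edges, np.eye(2))
```

In this chain, edge 0→1 receives nothing, because its only incoming candidate is its own reverse edge 1→0. Edge 1→2 receives exactly the state of 0→1.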

Attributes:

Name Type Description
hidden_dim

Dimension of hidden node representations.

in_features

Number of input node features.

num_edge_features

Number of edge features (default 4 for bond types).

Parameters:

Name Type Description Default
hidden_dim int

Dimension of hidden representations.

required
in_features int

Number of input node features (default: 4, a small value intended for tests).

4
num_edge_features int

Number of edge/bond features.

4
rngs Rngs

Flax NNX random number generators.

required

__call__

__call__(
    node_features: ndarray,
    adjacency: ndarray,
    edge_features: ndarray | None = None,
) -> ndarray

Perform one step of message passing.

Parameters:

Name Type Description Default
node_features ndarray

Node features of shape (num_nodes, in_features).

required
adjacency ndarray

Adjacency matrix of shape (num_nodes, num_nodes).

required
edge_features ndarray | None

Optional edge features of shape (num_nodes, num_nodes, num_edge_features).

None

Returns:

Type Description
ndarray

Updated node features of shape (num_nodes, hidden_dim).

StackedMessagePassing

diffbio.operators.drug_discovery.message_passing.StackedMessagePassing

StackedMessagePassing(
    hidden_dim: int,
    num_layers: int,
    in_features: int = 4,
    num_edge_features: int = 4,
    *,
    rngs: Rngs,
)

Bases: Module

Stack of message passing layers.

Applies multiple rounds of message passing to capture higher-order neighborhood information.
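Each additional round widens every atom's receptive field by one bond. The sketch below tracks only which atoms can influence which, using boolean powers of the adjacency matrix with self-loops, and ignores the learned weights entirely.

```python
import numpy as np


def receptive_field(adjacency: np.ndarray, num_layers: int) -> np.ndarray:
    """Boolean (n, n) matrix: entry [i, j] is True if atom j can influence atom i."""
    n = adjacency.shape[0]
    reach = np.eye(n, dtype=bool)
    step = (adjacency + np.eye(n)) > 0  # self-loops: an atom keeps its own state
    for _ in range(num_layers):
        reach = (reach.astype(int) @ step.astype(int)) > 0
    return reach


# Five-atom chain 0-1-2-3-4
chain = np.zeros((5, 5))
for i in range(4):
    chain[i, i + 1] = chain[i + 1, i] = 1.0
reach2 = receptive_field(chain, 2)
```

After two rounds, atom 0 sees atoms 0 to 2. Four rounds are needed before the two ends of the chain can influence each other.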

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension for all layers.

required
num_layers int

Number of message passing iterations.

required
in_features int

Number of input node features (default: 4, a small value intended for tests).

4
num_edge_features int

Number of edge features.

4
rngs Rngs

Flax NNX random number generators.

required

__call__

__call__(
    node_features: ndarray,
    adjacency: ndarray,
    edge_features: ndarray | None = None,
) -> ndarray

Apply multiple rounds of message passing.

Parameters:

Name Type Description Default
node_features ndarray

Initial node features.

required
adjacency ndarray

Adjacency matrix.

required
edge_features ndarray | None

Optional edge features.

None

Returns:

Type Description
ndarray

Final node representations.

Factory Functions

create_property_predictor

diffbio.operators.drug_discovery.property_predictor.create_property_predictor

create_property_predictor(
    hidden_dim: int = 300,
    num_layers: int = 3,
    num_tasks: int = 1,
    dropout_rate: float = 0.0,
    seed: int = 42,
) -> MolecularPropertyPredictor

Create a molecular property predictor.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension for message passing.

300
num_layers int

Number of message passing steps.

3
num_tasks int

Number of prediction tasks.

1
dropout_rate float

Dropout rate.

0.0
seed int

Random seed.

42

Returns:

Type Description
MolecularPropertyPredictor

Configured MolecularPropertyPredictor.

create_fingerprint_operator

diffbio.operators.drug_discovery.fingerprint.create_fingerprint_operator

create_fingerprint_operator(
    fingerprint_dim: int = 256,
    num_layers: int = 3,
    normalize: bool = False,
    seed: int = 42,
) -> DifferentiableMolecularFingerprint

Create a molecular fingerprint operator.

Parameters:

Name Type Description Default
fingerprint_dim int

Output fingerprint dimension.

256
num_layers int

Number of message passing layers.

3
normalize bool

Whether to L2-normalize output.

False
seed int

Random seed.

42

Returns:

Type Description
DifferentiableMolecularFingerprint

Configured DifferentiableMolecularFingerprint.

create_ecfp4_operator

diffbio.operators.drug_discovery.fingerprint.create_ecfp4_operator

create_ecfp4_operator(
    n_bits: int = 2048,
    differentiable: bool = True,
    rngs: Rngs | None = None,
) -> CircularFingerprintOperator

Create ECFP4 (radius=2) fingerprint operator.

ECFP4 captures circular substructures up to radius 2 around each atom, i.e. a diameter of 4 bonds.

Parameters:

Name Type Description Default
n_bits int

Number of fingerprint bits (default: 2048).

2048
differentiable bool

Use learned hash functions (default: True).

True
rngs Rngs | None

Random number generators.

None

Returns:

Type Description
CircularFingerprintOperator

Configured CircularFingerprintOperator.

create_ecfp6_operator

diffbio.operators.drug_discovery.fingerprint.create_ecfp6_operator

create_ecfp6_operator(
    n_bits: int = 2048,
    differentiable: bool = True,
    rngs: Rngs | None = None,
) -> CircularFingerprintOperator

Create ECFP6 (radius=3) fingerprint operator.

ECFP6 captures circular substructures up to radius 3 around each atom, i.e. a diameter of 6 bonds.

Parameters:

Name Type Description Default
n_bits int

Number of fingerprint bits (default: 2048).

2048
differentiable bool

Use learned hash functions (default: True).

True
rngs Rngs | None

Random number generators.

None

Returns:

Type Description
CircularFingerprintOperator

Configured CircularFingerprintOperator.

create_fcfp4_operator

diffbio.operators.drug_discovery.fingerprint.create_fcfp4_operator

create_fcfp4_operator(
    n_bits: int = 2048,
    differentiable: bool = True,
    rngs: Rngs | None = None,
) -> CircularFingerprintOperator

Create FCFP4 (feature-based, radius=2) fingerprint operator.

FCFP4 uses pharmacophoric atom features instead of atom-type properties, which makes it better suited to finding molecules with similar biological activity.

Parameters:

Name Type Description Default
n_bits int

Number of fingerprint bits (default: 2048).

2048
differentiable bool

Use learned hash functions (default: True).

True
rngs Rngs | None

Random number generators.

None

Returns:

Type Description
CircularFingerprintOperator

Configured CircularFingerprintOperator.

create_similarity_operator

diffbio.operators.drug_discovery.similarity.create_similarity_operator

create_similarity_operator(
    similarity_type: str = "tanimoto",
    temperature: float = 1.0,
) -> MolecularSimilarityOperator

Create a molecular similarity operator.

Parameters:

Name Type Description Default
similarity_type str

Type of similarity ("tanimoto", "cosine", "dice").

'tanimoto'
temperature float

Temperature parameter.

1.0

Returns:

Type Description
MolecularSimilarityOperator

Configured MolecularSimilarityOperator.

create_admet_predictor

diffbio.operators.drug_discovery.admet_predictor.create_admet_predictor

create_admet_predictor(
    hidden_dim: int = 300,
    num_layers: int = 3,
    dropout_rate: float = 0.0,
    seed: int = 42,
) -> ADMETPredictor

Create an ADMET predictor with standard configuration.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension for message passing.

300
num_layers int

Number of message passing steps.

3
dropout_rate float

Dropout rate.

0.0
seed int

Random seed.

42

Returns:

Type Description
ADMETPredictor

Configured ADMETPredictor.

create_maccs_operator

diffbio.operators.drug_discovery.maccs_keys.create_maccs_operator

create_maccs_operator(
    differentiable: bool = True,
    temperature: float = 1.0,
    seed: int = 42,
) -> MACCSKeysOperator

Create a MACCS keys fingerprint operator.

Parameters:

Name Type Description Default
differentiable bool

Use learned pattern matching.

True
temperature float

Temperature for soft matching.

1.0
seed int

Random seed.

42

Returns:

Type Description
MACCSKeysOperator

Configured MACCSKeysOperator.

create_attentive_fp

diffbio.operators.drug_discovery.attentive_fp.create_attentive_fp

create_attentive_fp(
    hidden_dim: int = 200,
    out_dim: int = 200,
    num_layers: int = 2,
    num_timesteps: int = 2,
    dropout_rate: float = 0.0,
    seed: int = 42,
) -> AttentiveFP

Create an AttentiveFP operator.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension for GNN layers.

200
out_dim int

Output fingerprint dimension.

200
num_layers int

Number of atom-level attention layers.

2
num_timesteps int

Number of molecule-level GRU iterations.

2
dropout_rate float

Dropout rate.

0.0
seed int

Random seed.

42

Returns:

Type Description
AttentiveFP

Configured AttentiveFP.

Similarity Functions

tanimoto_similarity

diffbio.operators.drug_discovery.similarity.tanimoto_similarity

tanimoto_similarity(
    a: ndarray, b: ndarray, eps: float = 1e-08
) -> ndarray

Compute differentiable Tanimoto similarity.

For continuous vectors, uses the generalized Tanimoto formula: T(a, b) = (a · b) / (|a|² + |b|² - a · b)

Parameters:

Name Type Description Default
a ndarray

First fingerprint vector.

required
b ndarray

Second fingerprint vector.

required
eps float

Small constant for numerical stability.

1e-08

Returns:

Type Description
ndarray

Similarity score in [0, 1] for non-negative inputs (for arbitrary real vectors the generalized formula ranges over [-1/3, 1]).
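On 0/1 fingerprints the generalized formula reduces to the classic bit-vector Tanimoto (Jaccard) coefficient. This is because a · b counts shared on-bits and |a|² + |b|² - a · b counts the union. The sketch below checks that equivalence and shows why the [0, 1] range holds only for non-negative inputs. Placing eps in the denominator is an assumption; the source does not say where it is added.

```python
import numpy as np


def tanimoto(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> float:
    dot = float(np.dot(a, b))
    # Generalized Tanimoto; eps in the denominator is an assumed placement.
    return dot / (float(np.dot(a, a)) + float(np.dot(b, b)) - dot + eps)


a = np.array([1, 1, 0, 1, 0], dtype=float)
b = np.array([1, 0, 0, 1, 1], dtype=float)
# Bit-vector Jaccard: |a AND b| / |a OR b| = 2 / 4
jaccard = np.logical_and(a, b).sum() / np.logical_or(a, b).sum()
print(tanimoto(a, b), jaccard)  # both approximately 0.5

v = np.array([1.0, 2.0])
print(tanimoto(v, -v))  # approximately -1/3, below the [0, 1] range
```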

cosine_similarity

diffbio.operators.drug_discovery.similarity.cosine_similarity

cosine_similarity(
    a: ndarray, b: ndarray, eps: float = 1e-08
) -> ndarray

Compute cosine similarity.

Parameters:

Name Type Description Default
a ndarray

First vector.

required
b ndarray

Second vector.

required
eps float

Small constant for numerical stability.

1e-08

Returns:

Type Description
ndarray

Similarity score in [-1, 1].
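Cosine similarity depends only on direction, whereas the generalized Tanimoto formula also penalizes differences in magnitude. The NumPy sketch below contrasts the two on a vector and a scaled copy of it. Adding eps to the product of norms is an assumption about where the stability constant goes.

```python
import numpy as np


def cosine(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + eps))


def tanimoto(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> float:
    dot = float(np.dot(a, b))
    return dot / (float(np.dot(a, a)) + float(np.dot(b, b)) - dot + eps)


a = np.array([0.2, 0.5, 0.1])
b = 3.0 * a  # same direction, three times the magnitude
print(cosine(a, b))  # approximately 1.0: magnitude is ignored
print(tanimoto(a, b))  # approximately 3/7 = 0.43: magnitude matters
```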

dice_similarity

diffbio.operators.drug_discovery.similarity.dice_similarity

dice_similarity(
    a: ndarray, b: ndarray, eps: float = 1e-08
) -> ndarray

Compute Dice similarity coefficient.

For continuous vectors:

Dice(a, b) = 2 * (a · b) / (|a|² + |b|²)

Parameters:

Name Type Description Default
a ndarray First vector. required
b ndarray Second vector. required
eps float Small constant for numerical stability. 1e-08

Returns:

Type Description
ndarray Similarity score; lies in [0, 1] when both inputs are non-negative (signed inputs can give values down to -1).
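A NumPy sketch of the continuous Dice formula, assuming `eps` is added to the denominator. It also checks a standard identity: with S = |a|² + |b|², Tanimoto is T = (a · b) / (S - a · b), so Dice = 2T / (1 + T). The two scores therefore rank pairs identically, and on non-negative inputs Dice is never smaller than Tanimoto.

```python
import numpy as np


def dice(a: np.ndarray, b: np.ndarray, eps: float = 1e-8):
    """Dice(a, b) = 2 (a.b) / (|a|^2 + |b|^2); eps placement is an assumption."""
    return 2.0 * np.dot(a, b) / (np.dot(a, a) + np.dot(b, b) + eps)


a = np.array([1.0, 1.0, 0.0, 1.0])
b = np.array([1.0, 0.0, 0.0, 1.0])
d = dice(a, b)                  # 2 * 2 / (3 + 2) = 0.8
t = 2.0 / (3.0 + 2.0 - 2.0)     # Tanimoto of the same pair = 2/3
print(round(d, 4), round(2 * t / (1 + t), 4))  # 0.8 0.8
```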

Graph Conversion Utilities

smiles_to_graph

diffbio.operators.drug_discovery.primitives.smiles_to_graph

smiles_to_graph(smiles: str) -> dict[str, Any]

Convert a SMILES string to a molecular graph.

Parameters:

Name Type Description Default
smiles str SMILES string representing a molecule. required

Returns:

Type Description
dict[str, Any] Dictionary containing:

  • node_features: (num_atoms, num_features) atom feature matrix
  • adjacency: (num_atoms, num_atoms) adjacency matrix
  • edge_features: (num_atoms, num_atoms, num_edge_features) bond features
  • num_nodes: number of atoms

Raises:

Type Description
ValueError If the SMILES string is invalid.
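To make the returned layout concrete, the sketch below builds a symmetric adjacency matrix and a one-channel edge tensor from a hand-written bond list for acetic acid (`CC(=O)O`), heavy atoms only. This page does not state whether `smiles_to_graph` includes explicit hydrogens or which bond properties fill `edge_features`, so the helper `bonds_to_graph` and its bond-order channel are hypothetical illustrations, not the library's encoding.

```python
import numpy as np


def bonds_to_graph(num_atoms: int, bonds: list[tuple[int, int, float]]) -> dict:
    """Hypothetical helper: undirected bond list -> dense adjacency + edge tensor."""
    adjacency = np.zeros((num_atoms, num_atoms), dtype=np.float32)
    # One illustrative edge channel holding the bond order (1.0 single, 2.0 double, ...)
    edge_features = np.zeros((num_atoms, num_atoms, 1), dtype=np.float32)
    for i, j, order in bonds:
        # Bonds are undirected, so fill both (i, j) and (j, i)
        adjacency[i, j] = adjacency[j, i] = 1.0
        edge_features[i, j, 0] = edge_features[j, i, 0] = order
    return {"adjacency": adjacency, "edge_features": edge_features, "num_nodes": num_atoms}


# Acetic acid CC(=O)O: atoms 0=C, 1=C, 2=O (carbonyl), 3=O (hydroxyl)
graph = bonds_to_graph(4, [(0, 1, 1.0), (1, 2, 2.0), (1, 3, 1.0)])
print(graph["adjacency"].sum(axis=1))  # node degrees: [1. 3. 1. 1.]
```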

batch_smiles_to_graphs

diffbio.operators.drug_discovery.primitives.batch_smiles_to_graphs

batch_smiles_to_graphs(
    smiles_list: list[str],
) -> dict[str, Any]

Convert a batch of SMILES strings to padded graph tensors.

Parameters:

Name Type Description Default
smiles_list list[str] List of SMILES strings. required

Returns:

Type Description
dict[str, Any] Dictionary containing:

  • node_features: (batch_size, max_nodes, num_features)
  • adjacency: (batch_size, max_nodes, max_nodes)
  • edge_features: (batch_size, max_nodes, max_nodes, num_edge_features)
  • node_mask: (batch_size, max_nodes) mask for valid nodes
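The padding step can be sketched as follows: each molecule's arrays are zero-padded up to a shared `max_nodes`, and `node_mask` marks the real atoms with 1.0. Padding to the largest molecule in the batch (rather than a fixed size) and the helper name `pad_graphs` are assumptions for illustration; `edge_features` would be padded the same way and is omitted for brevity.

```python
import numpy as np


def pad_graphs(node_feats: list[np.ndarray], adjs: list[np.ndarray]) -> dict:
    """Hypothetical helper: zero-pad per-molecule graphs to a shared max_nodes."""
    batch_size = len(node_feats)
    max_nodes = max(x.shape[0] for x in node_feats)  # assumption: pad to batch maximum
    num_features = node_feats[0].shape[1]
    node_features = np.zeros((batch_size, max_nodes, num_features), dtype=np.float32)
    adjacency = np.zeros((batch_size, max_nodes, max_nodes), dtype=np.float32)
    node_mask = np.zeros((batch_size, max_nodes), dtype=np.float32)
    for k, (x, a) in enumerate(zip(node_feats, adjs)):
        n = x.shape[0]
        node_features[k, :n] = x
        adjacency[k, :n, :n] = a
        node_mask[k, :n] = 1.0  # real atoms; padded slots stay 0.0
    return {"node_features": node_features, "adjacency": adjacency, "node_mask": node_mask}


# Two toy molecules with 3 and 6 atoms, 34 features per atom
feats = [np.ones((3, 34)), np.ones((6, 34))]
adjs = [np.zeros((3, 3)), np.zeros((6, 6))]  # placeholder adjacency matrices
batch = pad_graphs(feats, adjs)
print(batch["node_features"].shape, batch["node_mask"].sum(axis=1))  # (2, 6, 34) [3. 6.]
```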

AtomFeatureConfig

diffbio.operators.drug_discovery.primitives.AtomFeatureConfig dataclass

AtomFeatureConfig(
    num_atom_types: int = 12,
    max_degree: int = 6,
    charge_range: tuple[int, int] = (-2, 2),
    num_hybridization_types: int = 4,
    max_num_hydrogens: int = 4,
)

Configuration for atom feature extraction.

The default configuration produces 34 features:

  • Atom type: 12 dimensions (C, N, O, S, F, Cl, Br, I, P, Si, B, Other)
  • Degree: 7 dimensions (0-6)
  • Formal charge: 5 dimensions (-2 to +2)
  • Hybridization: 4 dimensions (SP, SP2, SP3, SP3D)
  • Aromaticity: 1 dimension (binary)
  • Num hydrogens: 5 dimensions (0-4)
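The 34-dimensional total follows directly from the config fields: 12 + (6 + 1) + (2 - (-2) + 1) + 4 + 1 + (4 + 1) = 34. The sketch below mirrors the documented defaults in a local stand-in dataclass (`AtomFeatureSpec`, hypothetical; it is not the library class) and one-hot encodes a single atom. Concatenating the blocks in the order listed above and clipping out-of-range values into the nearest bucket are assumptions; the library's actual ordering and edge-case handling may differ.

```python
from dataclasses import dataclass

import numpy as np

# Hypothetical vocabularies; order follows the AtomFeatureConfig description above
ATOM_TYPES = ["C", "N", "O", "S", "F", "Cl", "Br", "I", "P", "Si", "B"]  # index 11 = Other
HYBRIDIZATIONS = ["SP", "SP2", "SP3", "SP3D"]


@dataclass
class AtomFeatureSpec:
    """Hypothetical stand-in mirroring the documented AtomFeatureConfig defaults."""

    num_atom_types: int = 12
    max_degree: int = 6
    charge_range: tuple[int, int] = (-2, 2)
    num_hybridization_types: int = 4
    max_num_hydrogens: int = 4

    @property
    def total(self) -> int:
        lo, hi = self.charge_range
        return (
            self.num_atom_types              # atom type one-hot
            + (self.max_degree + 1)          # degree 0..max_degree
            + (hi - lo + 1)                  # formal charge lo..hi
            + self.num_hybridization_types   # hybridization one-hot
            + 1                              # aromatic flag
            + (self.max_num_hydrogens + 1)   # hydrogen count 0..max
        )


def one_hot(index: int, size: int) -> np.ndarray:
    vec = np.zeros(size, dtype=np.float32)
    vec[min(max(index, 0), size - 1)] = 1.0  # assumption: clip out-of-range values
    return vec


def encode_atom(spec: AtomFeatureSpec, symbol: str, degree: int, charge: int,
                hybridization: str, aromatic: bool, num_hs: int) -> np.ndarray:
    lo, hi = spec.charge_range
    type_idx = ATOM_TYPES.index(symbol) if symbol in ATOM_TYPES else spec.num_atom_types - 1
    return np.concatenate([
        one_hot(type_idx, spec.num_atom_types),
        one_hot(degree, spec.max_degree + 1),
        one_hot(charge - lo, hi - lo + 1),
        one_hot(HYBRIDIZATIONS.index(hybridization), spec.num_hybridization_types),
        np.array([1.0 if aromatic else 0.0], dtype=np.float32),
        one_hot(num_hs, spec.max_num_hydrogens + 1),
    ])


spec = AtomFeatureSpec()
# Hydroxyl oxygen of ethanol (CCO), heavy-atom view: degree 1, neutral, SP3, 1 H
x = encode_atom(spec, "O", 1, 0, "SP3", False, 1)
print(spec.total, x.shape, int(x.sum()))  # 34 (34,) 5
```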

Constants

DEFAULT_ATOM_FEATURES

diffbio.operators.drug_discovery.primitives.DEFAULT_ATOM_FEATURES module-attribute

DEFAULT_ATOM_FEATURES = total_features

Length of the atom feature vector produced by the default AtomFeatureConfig (34).

DEFAULT_ATOM_CONFIG

diffbio.operators.drug_discovery.primitives.DEFAULT_ATOM_CONFIG module-attribute

DEFAULT_ATOM_CONFIG = AtomFeatureConfig()

Usage Examples

Property Prediction

from diffbio.operators.drug_discovery import (
    MolecularPropertyPredictor,
    MolecularPropertyConfig,
    smiles_to_graph,
    DEFAULT_ATOM_FEATURES,
)
from flax import nnx

# Create predictor
config = MolecularPropertyConfig(
    hidden_dim=64,
    num_message_passing_steps=3,
    num_output_tasks=1,
    in_features=DEFAULT_ATOM_FEATURES,
)
predictor = MolecularPropertyPredictor(config, rngs=nnx.Rngs(42))

# Convert SMILES to a graph (smiles_to_graph returns a dict)
graph = smiles_to_graph("CCO")

# Predict properties
data = {
    "node_features": graph["node_features"],
    "adjacency": graph["adjacency"],
    "edge_features": graph["edge_features"],
}
result, _, _ = predictor.apply(data, {}, None)
predictions = result["predictions"]  # (1,)

Fingerprint Computation

from diffbio.operators.drug_discovery import (
    DifferentiableMolecularFingerprint,
    MolecularFingerprintConfig,
    smiles_to_graph,
    DEFAULT_ATOM_FEATURES,
)
from flax import nnx

# Create fingerprint operator
config = MolecularFingerprintConfig(
    fingerprint_dim=256,
    hidden_dim=128,
    num_layers=3,
    in_features=DEFAULT_ATOM_FEATURES,
    normalize=True,
)
fp_op = DifferentiableMolecularFingerprint(config, rngs=nnx.Rngs(42))

# Compute fingerprint (smiles_to_graph returns a dict)
graph = smiles_to_graph("c1ccccc1")
data = {"node_features": graph["node_features"], "adjacency": graph["adjacency"]}
result, _, _ = fp_op.apply(data, {}, None)
fingerprint = result["fingerprint"]  # (256,)

Circular Fingerprint (ECFP/Morgan)

from diffbio.operators.drug_discovery import (
    CircularFingerprintOperator,
    CircularFingerprintConfig,
    create_ecfp4_operator,
    create_ecfp6_operator,
    create_fcfp4_operator,
    smiles_to_graph,
    DEFAULT_ATOM_FEATURES,
)
from flax import nnx

# Using factory function for ECFP4-like fingerprints
ecfp4_op = create_ecfp4_operator(n_bits=2048)

# Or with full configuration
config = CircularFingerprintConfig(
    radius=2,              # ECFP4 (diameter 2 * radius = 4)
    n_bits=2048,           # Fingerprint dimension
    use_chirality=False,   # True: include stereochemistry
    use_features=False,    # True: FCFP instead of ECFP
    differentiable=True,   # Use learned hash functions
    hash_hidden_dim=128,   # Hidden dim for hash network
    temperature=1.0,       # Softmax temperature
    in_features=DEFAULT_ATOM_FEATURES,
)
fp_op = CircularFingerprintOperator(config, rngs=nnx.Rngs(42))

# Compute fingerprint from molecular graph (smiles_to_graph returns a dict)
graph = smiles_to_graph("c1ccccc1")
data = {
    "node_features": graph["node_features"],
    "adjacency": graph["adjacency"],
}
result, _, _ = fp_op.apply(data, {}, None)
fingerprint = result["fingerprint"]  # (2048,)

# Or compute directly from SMILES (RDKit mode)
config_rdkit = CircularFingerprintConfig(
    radius=2,
    n_bits=2048,
    differentiable=False,  # Use RDKit exact fingerprint
)
fp_op_rdkit = CircularFingerprintOperator(config_rdkit)
data = {"smiles": "c1ccccc1"}
result, _, _ = fp_op_rdkit.apply(data, {}, None)
exact_fingerprint = result["fingerprint"]  # (2048,), binary in RDKit mode
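The `temperature` option controls how sharply a softmax concentrates probability on a single bit. The sketch below is a hypothetical illustration of that effect, not the operator's internals: random per-atom logits over `n_bits` positions become soft bit assignments, and the per-atom distributions are combined with a noisy-OR (`1 - prod(1 - p)`) so every output stays in [0, 1]. The logits, the toy sizes, and the noisy-OR aggregation are all assumptions chosen for illustration.

```python
import numpy as np


def softmax(logits: np.ndarray, temperature: float) -> np.ndarray:
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def soft_fingerprint(logits: np.ndarray, temperature: float) -> np.ndarray:
    """logits: (num_atoms, n_bits) hypothetical hash scores -> (n_bits,) soft bits."""
    p = softmax(logits, temperature)       # each atom spreads unit mass over the bits
    return 1.0 - np.prod(1.0 - p, axis=0)  # noisy-OR: a bit is on if any atom hits it


rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 16))  # 6 atoms, 16 bits (toy sizes)
for t in (1.0, 0.1):
    # Lower temperature -> sharper per-atom assignments, closer to hard hashing
    fp = soft_fingerprint(logits, t)
    print(t, round(float(softmax(logits, t).max(axis=1).mean()), 3), round(float(fp.sum()), 3))
```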

Similarity Computation

from diffbio.operators.drug_discovery import (
    MolecularSimilarityOperator,
    MolecularSimilarityConfig,
    tanimoto_similarity,
)
from flax import nnx
import jax.numpy as jnp

# Using operator
config = MolecularSimilarityConfig(similarity_type="tanimoto")
sim_op = MolecularSimilarityOperator(config, rngs=nnx.Rngs(42))

fp1 = jnp.array([1.0, 0.5, 0.0, 0.8])
fp2 = jnp.array([0.9, 0.6, 0.1, 0.7])

data = {"fingerprint_a": fp1, "fingerprint_b": fp2}
result, _, _ = sim_op.apply(data, {}, None)
similarity = result["similarity"]

# Using standalone function
sim = tanimoto_similarity(fp1, fp2)
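For virtual screening, one query is usually scored against a whole library. The generalized Tanimoto formula vectorizes with a single matrix product; the sketch below is a NumPy illustration in which the name `tanimoto_matrix` and the `eps` placement are assumptions, and it does not imply that `tanimoto_similarity` itself accepts batched inputs.

```python
import numpy as np


def tanimoto_matrix(A: np.ndarray, B: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Pairwise generalized Tanimoto between rows of A (m, d) and B (n, d) -> (m, n)."""
    dots = A @ B.T                          # (m, n) pairwise dot products
    sq_a = np.sum(A * A, axis=1)[:, None]   # (m, 1) squared norms
    sq_b = np.sum(B * B, axis=1)[None, :]   # (1, n) squared norms
    return dots / (sq_a + sq_b - dots + eps)


query = np.array([[1.0, 0.5, 0.0, 0.8]])     # fp1 from the example above
library = np.array([[0.9, 0.6, 0.1, 0.7],    # fp2
                    [1.0, 0.5, 0.0, 0.8],    # identical to the query
                    [0.0, 0.0, 1.0, 0.0]])   # no overlap with the query
scores = tanimoto_matrix(query, library)     # shape (1, 3)
best = int(np.argmax(scores[0]))             # index of the most similar entry
print(scores.round(3), best)
```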

Batch Processing

from diffbio.operators.drug_discovery import batch_smiles_to_graphs

smiles_list = ["CCO", "CC(=O)O", "c1ccccc1"]
# Returns a dict of arrays padded to max_nodes atoms per molecule
graphs = batch_smiles_to_graphs(smiles_list)
node_features = graphs["node_features"]  # (3, max_nodes, 34)
adjacency = graphs["adjacency"]          # (3, max_nodes, max_nodes)
edge_features = graphs["edge_features"]  # (3, max_nodes, max_nodes, 4)
masks = graphs["node_mask"]              # (3, max_nodes)

Gradient Computation

from flax import nnx
from diffbio.operators.drug_discovery import (
    create_property_predictor,
    smiles_to_graph,
)

predictor = create_property_predictor(
    hidden_dim=32,
    num_layers=2,
    num_tasks=1,
)

# smiles_to_graph returns a dict, so select the arrays by key
graph = smiles_to_graph("CCO")
data = {
    "node_features": graph["node_features"],
    "adjacency": graph["adjacency"],
    "edge_features": graph["edge_features"],
}

def loss_fn(model, data):
    # Toy scalar loss: sum of the predicted properties
    result, _, _ = model.apply(data, {}, None)
    return result["predictions"].sum()

# nnx.grad differentiates with respect to the model's nnx.Param variables
grads = nnx.grad(loss_fn)(predictor, data)

Input Specifications

MolecularPropertyPredictor

Key Shape Type Description
node_features (n, in_features) float32 Atom feature vectors
adjacency (n, n) float32 Adjacency matrix
edge_features (n, n, num_edge_features) float32 Optional bond features
node_mask (n,) float32 Optional mask for valid atoms
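The predictor's graph-level readout is described as sum pooling over atoms, and `node_mask` marks which rows are real atoms. A minimal sketch of masked sum pooling shows why padded rows must be zeroed before pooling; the function name and the multiply-by-mask approach are assumptions about how the mask is applied, not the library's code.

```python
import numpy as np


def masked_sum_pool(node_repr: np.ndarray, node_mask: np.ndarray) -> np.ndarray:
    """node_repr: (n, hidden_dim); node_mask: (n,) with 1.0 for real atoms -> (hidden_dim,)."""
    # Broadcasting the mask over the feature axis zeroes out padded rows
    return np.sum(node_repr * node_mask[:, None], axis=0)


# 3 real atoms padded to 5 rows; padded rows hold arbitrary values
node_repr = np.array([[1.0, 2.0],
                      [3.0, 4.0],
                      [5.0, 6.0],
                      [9.0, 9.0],
                      [7.0, 7.0]])
node_mask = np.array([1.0, 1.0, 1.0, 0.0, 0.0])
print(masked_sum_pool(node_repr, node_mask))  # [ 9. 12.]
```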

DifferentiableMolecularFingerprint

Key Shape Type Description
node_features (n, in_features) float32 Atom feature vectors
adjacency (n, n) float32 Adjacency matrix
edge_features (n, n, num_edge_features) float32 Optional bond features
node_mask (n,) float32 Optional mask for valid atoms

CircularFingerprintOperator

Differentiable mode (differentiable=True):

Key Shape Type Description
node_features (n, in_features) float32 Atom feature vectors
adjacency (n, n) float32 Adjacency matrix
node_mask (n,) float32 Optional mask for valid atoms

RDKit mode (differentiable=False):

Key Type Description
smiles str SMILES string

MolecularSimilarityOperator

Key Shape Type Description
fingerprint_a (dim,) float32 First fingerprint vector
fingerprint_b (dim,) float32 Second fingerprint vector

Output Specifications

MolecularPropertyPredictor

Key Shape Type Description
predictions (num_tasks,) float32 Property predictions
graph_representation (hidden_dim,) float32 Graph-level embedding

DifferentiableMolecularFingerprint

Key Shape Type Description
fingerprint (fingerprint_dim,) float32 Molecular fingerprint

CircularFingerprintOperator

Key Shape Type Description
fingerprint (n_bits,) float32 Circular fingerprint (soft probabilities in differentiable mode, binary in RDKit mode)

MolecularSimilarityOperator

Key Shape Type Description
similarity () float32 Similarity score

Atom Feature Dimensions

Feature Dimensions Description
Atom type 12 C, N, O, S, F, Cl, Br, I, P, Si, B, other
Degree 7 0-6 neighbors
Formal charge 5 -2 to +2
Hybridization 4 SP, SP2, SP3, SP3D
Aromaticity 1 Binary
Hydrogens 5 0-4
Total 34 DEFAULT_ATOM_FEATURES