Molecular Dynamics Operators API

Differentiable operators for molecular dynamics simulations using JAX-MD.

ForceFieldOperator

diffbio.operators.molecular_dynamics.force_field.ForceFieldOperator

ForceFieldOperator(
    config: ForceFieldConfig, *, rngs: Rngs | None = None
)

Bases: OperatorModule

Differentiable force field operator using JAX-MD.

Computes the potential energy and forces for a system of particles using classical pairwise potentials. Forces are obtained as the negative gradient of the energy via JAX automatic differentiation.

Supported potentials
  • lennard_jones: Standard 12-6 LJ potential
  • morse: Morse potential for bonded interactions
  • soft_sphere: Soft repulsive potential
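The three potentials have standard closed forms in terms of the pair distance r. The plain-JAX sketch below follows the parameterization of JAX-MD's `energy.lennard_jones`, `energy.morse`, and `energy.soft_sphere`, with the soft-sphere exponent fixed at 2. It is illustrative only: whether diffbio forwards `sigma`, `epsilon`, and `alpha` to these forms unchanged is an assumption.

```python
import jax.numpy as jnp


def lennard_jones(r, sigma=1.0, epsilon=1.0):
    """12-6 Lennard-Jones: 4*eps*[(sigma/r)^12 - (sigma/r)^6]."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


def morse(r, sigma=1.0, epsilon=1.0, alpha=5.0):
    """Morse: eps*(1 - exp(-alpha*(r - sigma)))^2 - eps, minimum -eps at r = sigma."""
    return epsilon * (1.0 - jnp.exp(-alpha * (r - sigma))) ** 2 - epsilon


def soft_sphere(r, sigma=1.0, epsilon=1.0, exponent=2.0):
    """Purely repulsive: (eps/exponent)*(1 - r/sigma)^exponent for r < sigma, else 0."""
    overlap = jnp.maximum(1.0 - r / sigma, 0.0)  # zero once particles stop overlapping
    return epsilon / exponent * overlap**exponent


r = jnp.array([0.9, 1.0, 2.0 ** (1.0 / 6.0), 1.5])
print(lennard_jones(r))
print(morse(r))
print(soft_sphere(r))
```

Note the different minima: Lennard-Jones reaches -epsilon at r = 2^(1/6) sigma, while Morse reaches -epsilon at r = sigma. This is why sigma acts as an equilibrium bond length for Morse but as a particle diameter for Lennard-Jones.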
Example
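from flax import nnx
import jax
from diffbio.operators.molecular_dynamics import ForceFieldConfig, ForceFieldOperator
config = ForceFieldConfig(potential_type="lennard_jones", box_size=10.0)
operator = ForceFieldOperator(config, rngs=nnx.Rngs(42))
positions = jax.random.uniform(jax.random.PRNGKey(0), (20, 3), maxval=10.0)
data = {"positions": positions}  # (n_particles, dim)
result, state, meta = operator.apply(data, {}, None)
energy = result["energy"]  # scalar
forces = result["forces"]  # (n_particles, dim)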

Parameters:

  • config (ForceFieldConfig, required): Force field configuration.
  • rngs (Rngs | None, default None): Flax NNX random number generators.

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Compute energy and forces for particle positions.

Parameters:

  • data (dict[str, Any], required): Input data containing "positions", the particle positions with shape (n_particles, dim) or (batch, n_particles, dim).
  • state (dict[str, Any], required): Per-element state (passed through).
  • metadata (dict[str, Any] | None, required): Optional metadata (may be None).
  • random_params (Any, default None): Unused random parameters.
  • stats (dict[str, Any] | None, default None): Optional statistics dictionary.

Returns:

  • tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]: A tuple of (data with added "energy" and "forces" keys, the unchanged state, the unchanged metadata).
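A minimal sketch of the energy-and-forces contract described above, assuming free (non-periodic) boundaries and a plain all-pairs Lennard-Jones sum. The helper `energy_and_forces` is hypothetical and not part of diffbio. It shows one way a single function can accept either `(n_particles, dim)` or `(batch, n_particles, dim)` positions: map over the leading axis with `jax.vmap` and take forces as the negative autodiff gradient.

```python
import jax
import jax.numpy as jnp


def total_energy(positions, sigma=1.0, epsilon=1.0):
    """Lennard-Jones energy summed over unique pairs, free boundaries."""
    n = positions.shape[0]
    dr = positions[:, None, :] - positions[None, :, :]  # (n, n, dim)
    mask = ~jnp.eye(n, dtype=bool)  # exclude self-pairs
    r2 = jnp.where(mask, jnp.sum(dr**2, axis=-1), 1.0)  # safe value on the diagonal
    sr6 = (sigma**2 / r2) ** 3
    pair = jnp.where(mask, 4.0 * epsilon * (sr6**2 - sr6), 0.0)
    return 0.5 * jnp.sum(pair)  # each pair is counted twice


def energy_and_forces(positions):
    """Return a dict with energy and forces (forces = -dE/dR via autodiff)."""
    single = jax.value_and_grad(total_energy)
    if positions.ndim == 3:  # (batch, n, dim)
        energy, grad = jax.vmap(single)(positions)
    else:  # (n, dim)
        energy, grad = single(positions)
    return {"positions": positions, "energy": energy, "forces": -grad}


batch = jax.random.uniform(jax.random.PRNGKey(0), (8, 20, 3), minval=0.0, maxval=10.0)
out = energy_and_forces(batch)
print(out["energy"].shape, out["forces"].shape)
```

The masked diagonal uses the "double `where`" pattern: replacing r^2 with a dummy value before dividing keeps the self-pair terms finite, so their gradients are zero instead of NaN.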

ForceFieldConfig

diffbio.operators.molecular_dynamics.force_field.ForceFieldConfig dataclass

ForceFieldConfig(
    potential_type: str = "lennard_jones",
    sigma: float = 1.0,
    epsilon: float = 1.0,
    cutoff: float | None = 2.5,
    box_size: float | None = None,
    alpha: float = 5.0,
    geometric_loss_weight: float = 0.0,
)

Bases: OperatorConfig

Configuration for force field operator.

Attributes:

  • potential_type (str): Type of potential ("lennard_jones", "morse", or "soft_sphere").
  • sigma (float): Length scale parameter (particle diameter).
  • epsilon (float): Energy scale parameter (well depth).
  • cutoff (float | None): Cutoff distance for interactions, in units of sigma. None for no cutoff.
  • box_size (float | None): Size of the periodic box. None for non-periodic boundaries.
  • alpha (float): Morse potential width parameter (used only by "morse").
  • reference_positions (array of floats, optional): Reference positions for computing geometric losses (chamfer distance and earth mover's distance) via artifex. If provided, "chamfer_distance" and "earth_mover_distance" are added to the output dict. The shape must match the particle positions.
  • geometric_loss_weight (float): Weight for the geometric loss terms.
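Because `cutoff` is expressed in units of `sigma`, the absolute interaction range is `cutoff * sigma` (2.5 with the defaults). The sketch below illustrates that convention with a hard truncation; `truncated_lj` is a hypothetical helper. JAX-MD's own Lennard-Jones pair function instead tapers the energy smoothly to zero between an onset radius and the cutoff, and this page does not say which variant diffbio uses.

```python
import jax.numpy as jnp


def truncated_lj(r, sigma=1.0, epsilon=1.0, cutoff=2.5):
    """Lennard-Jones pair energy set to zero beyond cutoff * sigma."""
    r_cut = cutoff * sigma  # cutoff is given in units of sigma
    sr6 = (sigma / r) ** 6
    lj = 4.0 * epsilon * (sr6**2 - sr6)
    return jnp.where(r < r_cut, lj, 0.0)


# With sigma = 2.0, the same cutoff = 2.5 reaches out to r = 5.0.
r = jnp.array([2.0, 4.9, 5.1])
print(truncated_lj(r, sigma=2.0))
```

A hard truncation leaves an energy jump of about -0.016 epsilon at r = 2.5 sigma, which is one reason smooth cutoffs are usually preferred when energies are integrated over time.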

MDIntegratorOperator

diffbio.operators.molecular_dynamics.integrator.MDIntegratorOperator

MDIntegratorOperator(
    config: MDIntegratorConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)

Bases: OperatorModule

Differentiable MD integrator operator using JAX-MD.

Evolves particle positions and velocities over time using classical molecular dynamics integration schemes.

Supported integrators
  • velocity_verlet: Symplectic velocity Verlet (NVE)
  • nvt_langevin: Langevin dynamics for NVT ensemble
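The velocity Verlet scheme listed above can be sketched in plain JAX against a generic `force_fn`. This is an illustrative sketch, not diffbio's implementation, which the page says is built on JAX-MD. Using `jax.lax.scan` keeps the rollout differentiable and produces a trajectory that includes the initial frame, matching the documented `(n_steps + 1, n, dim)` shape.

```python
import jax
import jax.numpy as jnp


def velocity_verlet(force_fn, positions, velocities, dt=0.001, n_steps=100, mass=1.0):
    """NVE integration; returns final positions, velocities and the trajectory."""

    def step(carry, _):
        x, v, f = carry
        v_half = v + 0.5 * dt * f / mass  # half kick
        x_new = x + dt * v_half  # drift
        f_new = force_fn(x_new)
        v_new = v_half + 0.5 * dt * f_new / mass  # second half kick
        return (x_new, v_new, f_new), x_new

    init = (positions, velocities, force_fn(positions))
    (x, v, _), frames = jax.lax.scan(step, init, None, length=n_steps)
    trajectory = jnp.concatenate([positions[None], frames], axis=0)
    return x, v, trajectory


# Usage: a 1D harmonic oscillator with k = m = 1 (force = -x).
x0 = jnp.array([[1.0]])
v0 = jnp.array([[0.0]])
x, v, traj = velocity_verlet(lambda x: -x, x0, v0, dt=0.01, n_steps=1000)
```

Each step reuses the force from the previous step, so only one force evaluation is needed per step. With a periodic box, the drift would go through a shift function (for example the one returned by JAX-MD's `space.periodic`) rather than plain addition.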
Example
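from flax import nnx
import jax
from diffbio.operators.molecular_dynamics import MDIntegratorConfig, MDIntegratorOperator
config = MDIntegratorConfig(dt=0.001, n_steps=1000, box_size=10.0)
integrator = MDIntegratorOperator(config, rngs=nnx.Rngs(42))
key_pos, key_vel = jax.random.split(jax.random.PRNGKey(0))
positions = jax.random.uniform(key_pos, (20, 3), minval=2.0, maxval=8.0)
velocities = 0.1 * jax.random.normal(key_vel, (20, 3))
data = {"positions": positions, "velocities": velocities}
result, state, meta = integrator.apply(data, {}, None)
final_positions = result["positions"]
trajectory = result["trajectory"]  # (n_steps + 1, n_particles, dim)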

Parameters:

  • config (MDIntegratorConfig, required): Integrator configuration.
  • rngs (Rngs | None, default None): Flax NNX random number generators.
  • name (str | None, default None): Optional name for the operator.

apply

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Run the MD simulation for the configured number of steps (n_steps).

Parameters:

  • data (dict[str, Any], required): Input data containing "positions" (initial particle positions, shape (n_particles, dim)) and "velocities" (initial particle velocities, shape (n_particles, dim)).
  • state (dict[str, Any], required): Per-element state (passed through).
  • metadata (dict[str, Any] | None, required): Optional metadata (may be None).
  • random_params (Any, default None): Unused random parameters.
  • stats (dict[str, Any] | None, default None): Optional statistics dictionary.

Returns:

  • tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]: A tuple of (data with updated "positions" and "velocities" plus a "trajectory" key, the unchanged state, the unchanged metadata).

MDIntegratorConfig

diffbio.operators.molecular_dynamics.integrator.MDIntegratorConfig dataclass

MDIntegratorConfig(
    integrator_type: str = "velocity_verlet",
    dt: float = 0.001,
    n_steps: int = 100,
    box_size: float | None = 10.0,
    potential_type: str = "lennard_jones",
    sigma: float = 1.0,
    epsilon: float = 1.0,
    mass: float = 1.0,
    kT: float = 1.0,
    gamma: float = 1.0,
)

Bases: OperatorConfig

Configuration for MD integrator operator.

Attributes:

  • integrator_type (str): Type of integrator ("velocity_verlet" or "nvt_langevin").
  • dt (float): Integration time step.
  • n_steps (int): Number of integration steps.
  • box_size (float | None): Size of the periodic box. None for non-periodic boundaries.
  • potential_type (str): Type of potential ("lennard_jones", "morse", or "soft_sphere").
  • sigma (float): Sigma parameter of the potential (length scale).
  • epsilon (float): Epsilon parameter of the potential (energy scale).
  • mass (float): Particle mass (uniform for all particles).
  • kT (float): Thermal energy for the Langevin thermostat.
  • gamma (float): Friction coefficient for Langevin dynamics.
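For the `nvt_langevin` integrator, `kT` and `gamma` enter through a stochastic velocity update. The sketch below shows a BAOAB Langevin step, the splitting JAX-MD documents for `simulate.nvt_langevin`. The names `baoab_step` and `run_langevin` are hypothetical, and diffbio's actual update is not shown on this page.

```python
import jax
import jax.numpy as jnp


def baoab_step(force_fn, x, v, key, dt=0.001, mass=1.0, kT=1.0, gamma=1.0):
    """One BAOAB Langevin step (B: kick, A: drift, O: Ornstein-Uhlenbeck)."""
    v = v + 0.5 * dt * force_fn(x) / mass  # B
    x = x + 0.5 * dt * v  # A
    c1 = jnp.exp(-gamma * dt)  # friction decay over one step
    c2 = jnp.sqrt((1.0 - c1**2) * kT / mass)  # matching thermal noise amplitude
    v = c1 * v + c2 * jax.random.normal(key, v.shape)  # O
    x = x + 0.5 * dt * v  # A
    v = v + 0.5 * dt * force_fn(x) / mass  # B
    return x, v


def run_langevin(force_fn, x, v, key, n_steps=100, **kw):
    """Roll out n_steps BAOAB steps with a fresh PRNG key per step."""
    keys = jax.random.split(key, n_steps)

    def body(carry, k):
        x, v = carry
        x, v = baoab_step(force_fn, x, v, k, **kw)
        return (x, v), x

    (x, v), traj = jax.lax.scan(body, (x, v), keys)
    return x, v, traj


# Usage: 2000 independent 1D harmonic oscillators (force = -x) started at rest.
key = jax.random.PRNGKey(0)
x0 = jnp.zeros((2000, 1))
v0 = jnp.zeros((2000, 1))
x, v, traj = run_langevin(lambda x: -x, x0, v0, key, n_steps=500, dt=0.05, kT=1.0, gamma=1.0)
print(jnp.mean(v**2))  # equipartition: expected to be near kT / mass
```

Here `gamma` acts as a rate (inverse time): velocities lose memory over roughly 1/gamma, and the noise amplitude is tied to `kT` so that the velocity distribution relaxes toward the Maxwell-Boltzmann distribution at temperature kT.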

PotentialType

diffbio.operators.molecular_dynamics.primitives.PotentialType

Bases: StrEnum

Enumeration of supported potential types.

Factory Functions

create_force_field

diffbio.operators.molecular_dynamics.force_field.create_force_field

create_force_field(
    potential_type: str | PotentialType = PotentialType.LENNARD_JONES,
    sigma: float = 1.0,
    epsilon: float = 1.0,
    cutoff: float | None = 2.5,
    box_size: float | None = None,
    alpha: float = 5.0,
    seed: int = 42,
) -> ForceFieldOperator

Create a force field operator with specified potential.

Parameters:

  • potential_type (str | PotentialType, default PotentialType.LENNARD_JONES): Type of potential ("lennard_jones", "morse", or "soft_sphere"), as a string or a PotentialType member.
  • sigma (float, default 1.0): Particle diameter (length scale).
  • epsilon (float, default 1.0): Well depth (energy scale).
  • cutoff (float | None, default 2.5): Cutoff distance in units of sigma. None for no cutoff.
  • box_size (float | None, default None): Periodic box size. None for non-periodic boundaries.
  • alpha (float, default 5.0): Morse potential width parameter.
  • seed (int, default 42): Random seed for initialization.

Returns:

  • ForceFieldOperator: The configured ForceFieldOperator.

create_lennard_jones_operator

diffbio.operators.molecular_dynamics.force_field.create_lennard_jones_operator

create_lennard_jones_operator(
    sigma: float = 1.0,
    epsilon: float = 1.0,
    cutoff: float | None = 2.5,
    box_size: float | None = None,
    seed: int = 42,
) -> ForceFieldOperator

Create a Lennard-Jones force field operator.

This is a convenience function for creating a force field operator with the Lennard-Jones potential.

Parameters:

  • sigma (float, default 1.0): Particle diameter (length scale).
  • epsilon (float, default 1.0): Well depth (energy scale).
  • cutoff (float | None, default 2.5): Cutoff distance in units of sigma. None for no cutoff.
  • box_size (float | None, default None): Periodic box size. None for non-periodic boundaries.
  • seed (int, default 42): Random seed for initialization.

Returns:

  • ForceFieldOperator: The configured ForceFieldOperator.

create_integrator

diffbio.operators.molecular_dynamics.integrator.create_integrator

create_integrator(
    integrator_type: str = "velocity_verlet",
    dt: float = 0.001,
    n_steps: int = 100,
    box_size: float | None = 10.0,
    potential_type: str | PotentialType = PotentialType.LENNARD_JONES,
    sigma: float = 1.0,
    epsilon: float = 1.0,
    mass: float = 1.0,
    kT: float = 1.0,
    gamma: float = 1.0,
    seed: int = 42,
) -> MDIntegratorOperator

Create an MD integrator operator.

Parameters:

  • integrator_type (str, default "velocity_verlet"): Type of integrator ("velocity_verlet" or "nvt_langevin").
  • dt (float, default 0.001): Integration time step.
  • n_steps (int, default 100): Number of integration steps.
  • box_size (float | None, default 10.0): Periodic box size. None for non-periodic boundaries.
  • potential_type (str | PotentialType, default PotentialType.LENNARD_JONES): Type of potential ("lennard_jones", "morse", or "soft_sphere"), as a string or a PotentialType member.
  • sigma (float, default 1.0): Sigma parameter of the potential (length scale).
  • epsilon (float, default 1.0): Epsilon parameter of the potential (energy scale).
  • mass (float, default 1.0): Particle mass.
  • kT (float, default 1.0): Thermal energy for the Langevin thermostat.
  • gamma (float, default 1.0): Friction coefficient for Langevin dynamics.
  • seed (int, default 42): Random seed for initialization.

Returns:

  • MDIntegratorOperator: The configured MDIntegratorOperator.

create_verlet_integrator

diffbio.operators.molecular_dynamics.integrator.create_verlet_integrator

create_verlet_integrator(
    dt: float = 0.001,
    n_steps: int = 100,
    box_size: float | None = 10.0,
    sigma: float = 1.0,
    epsilon: float = 1.0,
    seed: int = 42,
) -> MDIntegratorOperator

Create a velocity Verlet integrator operator.

This is a convenience function for creating an NVE integrator that uses the velocity Verlet algorithm.

Parameters:

  • dt (float, default 0.001): Integration time step.
  • n_steps (int, default 100): Number of integration steps.
  • box_size (float | None, default 10.0): Periodic box size. None for non-periodic boundaries.
  • sigma (float, default 1.0): Sigma parameter of the potential.
  • epsilon (float, default 1.0): Epsilon parameter of the potential.
  • seed (int, default 42): Random seed for initialization.

Returns:

  • MDIntegratorOperator: The configured MDIntegratorOperator.

Primitive Functions

create_displacement_fn

diffbio.operators.molecular_dynamics.primitives.create_displacement_fn

create_displacement_fn(
    box_size: float | None = None,
) -> tuple[Callable, Callable]

Create displacement and shift functions based on boundary conditions.

Parameters:

  • box_size (float | None, default None): Size of the periodic box. None for non-periodic (free) boundaries.

Returns:

  • tuple[Callable, Callable]: A tuple (displacement_fn, shift_fn), where displacement_fn computes the displacement vector between two points and shift_fn applies a displacement to a position while respecting the boundary conditions.
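The returned pair follows the pattern of JAX-MD's `space.free()` and `space.periodic(box_size)`. A minimal sketch of both cases, assuming a cubic box and the minimum-image convention; the function name `make_space` is hypothetical.

```python
import jax.numpy as jnp


def make_space(box_size=None):
    """Return (displacement_fn, shift_fn) for free or cubic periodic boundaries."""
    if box_size is None:

        def displacement_fn(ra, rb):
            return ra - rb

        def shift_fn(r, dr):
            return r + dr

    else:

        def displacement_fn(ra, rb):
            dr = ra - rb
            # Minimum image: wrap each component into [-L/2, L/2).
            return jnp.mod(dr + 0.5 * box_size, box_size) - 0.5 * box_size

        def shift_fn(r, dr):
            # Move, then wrap back into the box [0, L).
            return jnp.mod(r + dr, box_size)

    return displacement_fn, shift_fn


disp, shift = make_space(box_size=10.0)
a = jnp.array([0.5, 0.0, 0.0])
b = jnp.array([9.5, 0.0, 0.0])
print(disp(a, b))  # x-component +1.0: the nearest image of b sits just across the boundary
print(shift(b, jnp.array([1.0, 0.0, 0.0])))  # x-component wraps from 10.5 back to 0.5
```

Keeping displacement and shift as separate callables lets the same energy and integrator code run unchanged for both boundary conditions.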

create_energy_fn

diffbio.operators.molecular_dynamics.primitives.create_energy_fn

create_energy_fn(
    displacement_fn: Callable,
    potential_type: PotentialType | str = PotentialType.LENNARD_JONES,
    sigma: float = 1.0,
    epsilon: float = 1.0,
    cutoff: float | None = None,
    alpha: float = 5.0,
) -> Callable

Create energy function for the specified potential.

Parameters:

  • displacement_fn (Callable, required): Displacement function returned by create_displacement_fn.
  • potential_type (PotentialType | str, default PotentialType.LENNARD_JONES): Type of potential to use.
  • sigma (float, default 1.0): Length scale parameter (particle diameter).
  • epsilon (float, default 1.0): Energy scale parameter (well depth).
  • cutoff (float | None, default None): Cutoff distance for interactions. None for no cutoff.
  • alpha (float, default 5.0): Morse potential width parameter (used only by "morse").

Returns:

  • Callable: Energy function that takes positions and returns the total potential energy.

Raises:

  • ValueError: If potential_type is not recognized.
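Conceptually, the returned energy function sums a pair potential over all particle pairs, measuring distances with `displacement_fn`. A minimal sketch for the Lennard-Jones case, assuming an all-pairs sum without cutoff or neighbor lists (JAX-MD builds such sums with `smap.pair`); `make_lj_energy_fn` is hypothetical.

```python
import jax
import jax.numpy as jnp


def make_lj_energy_fn(displacement_fn, sigma=1.0, epsilon=1.0):
    """Total Lennard-Jones energy over unique pairs, using displacement_fn."""
    # Nested vmap gives dr[i, j] = displacement_fn(R[i], R[j]) with shape (n, n, dim).
    pairwise = jax.vmap(jax.vmap(displacement_fn, (None, 0)), (0, None))

    def energy_fn(positions):
        n = positions.shape[0]
        dr = pairwise(positions, positions)
        off_diag = ~jnp.eye(n, dtype=bool)
        # Substitute a dummy distance on the diagonal to avoid 0/0 in gradients.
        r = jnp.sqrt(jnp.where(off_diag, jnp.sum(dr**2, axis=-1), 1.0))
        sr6 = (sigma / r) ** 6
        e = 4.0 * epsilon * (sr6**2 - sr6)
        return 0.5 * jnp.sum(jnp.where(off_diag, e, 0.0))  # each pair counted twice

    return energy_fn


energy_fn = make_lj_energy_fn(lambda ra, rb: ra - rb)  # free boundaries
dimer = jnp.array([[0.0, 0.0, 0.0], [2.0 ** (1.0 / 6.0), 0.0, 0.0]])
print(energy_fn(dimer))  # approximately -1.0, the bottom of the LJ well
```

Passing a periodic displacement function instead (for example one built with the minimum-image convention) would give the periodic energy without changing `energy_fn`.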

create_force_fn

diffbio.operators.molecular_dynamics.primitives.create_force_fn

create_force_fn(energy_fn: Callable) -> Callable

Create force function from energy function.

Forces are computed as the negative gradient of the energy.

Parameters:

  • energy_fn (Callable, required): Energy function that takes positions and returns the energy.

Returns:

  • Callable: Force function that takes positions and returns the forces.
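In JAX this reduces to negating `jax.grad` of the energy. The sketch below uses the hypothetical name `make_force_fn`, with a two-particle Lennard-Jones energy as a test case only. It also checks two properties any pairwise force field should satisfy: the forces on a pair are equal and opposite, and they vanish at the potential minimum.

```python
import jax
import jax.numpy as jnp


def make_force_fn(energy_fn):
    """Forces are the negative gradient of the energy with respect to positions."""
    grad_fn = jax.grad(energy_fn)
    return lambda positions: -grad_fn(positions)


def dimer_lj_energy(positions, sigma=1.0, epsilon=1.0):
    """Lennard-Jones energy of exactly two particles (test case only)."""
    r = jnp.linalg.norm(positions[0] - positions[1])
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6**2 - sr6)


force_fn = make_force_fn(dimer_lj_energy)
compressed = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
f = force_fn(compressed)
# At r = sigma the pair is closer than the LJ minimum (2^(1/6) sigma), so the
# force is repulsive: particle 1 is pushed in +x and particle 0 in -x.
print(f)
```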

Usage Examples

Force Field Computation

from diffbio.operators.molecular_dynamics import (
    create_force_field,
    PotentialType,
)
import jax
import jax.numpy as jnp

# Create Lennard-Jones force field
force_field = create_force_field(
    potential_type=PotentialType.LENNARD_JONES,
    sigma=1.0,
    epsilon=1.0,
    box_size=10.0,
)

# Generate positions
positions = jax.random.uniform(
    jax.random.PRNGKey(0), (20, 3), minval=0, maxval=10.0
)

# Compute energy and forces
result, _, _ = force_field.apply({"positions": positions}, {}, None)
energy = result["energy"]  # scalar
forces = result["forces"]  # (20, 3)

MD Simulation

from diffbio.operators.molecular_dynamics import create_verlet_integrator
import jax

# Create integrator
integrator = create_verlet_integrator(
    dt=0.001,
    n_steps=1000,
    box_size=10.0,
)

# Initial conditions
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
positions = jax.random.uniform(key1, (20, 3), minval=2, maxval=8.0)
velocities = jax.random.normal(key2, (20, 3)) * 0.1

# Run simulation
result, _, _ = integrator.apply(
    {"positions": positions, "velocities": velocities}, {}, None
)
trajectory = result["trajectory"]  # (1001, 20, 3)

Full Configuration

from diffbio.operators.molecular_dynamics import (
    ForceFieldOperator,
    ForceFieldConfig,
    MDIntegratorOperator,
    MDIntegratorConfig,
)
from flax import nnx

# Force field with custom config
ff_config = ForceFieldConfig(
    potential_type="morse",
    sigma=1.0,
    epsilon=2.0,
    alpha=5.0,
    box_size=15.0,
)
force_field = ForceFieldOperator(ff_config, rngs=nnx.Rngs(42))

# Integrator with custom config
int_config = MDIntegratorConfig(
    integrator_type="nvt_langevin",
    dt=0.002,
    n_steps=500,
    box_size=15.0,
    kT=1.0,
    gamma=0.5,
)
integrator = MDIntegratorOperator(int_config, rngs=nnx.Rngs(42))

Batched Processing

import jax
from diffbio.operators.molecular_dynamics import create_force_field

force_field = create_force_field(box_size=10.0)

# Batch of configurations
batch_size = 8
n_particles = 20
dim = 3

positions = jax.random.uniform(
    jax.random.PRNGKey(0),
    (batch_size, n_particles, dim),
    minval=0,
    maxval=10.0,
)

# Force field handles batched input
result, _, _ = force_field.apply({"positions": positions}, {}, None)
energies = result["energy"]  # (8,)
forces = result["forces"]    # (8, 20, 3)

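Gradient Computation

import jax
from diffbio.operators.molecular_dynamics import create_force_field

force_field = create_force_field(box_size=10.0)
positions = jax.random.uniform(
    jax.random.PRNGKey(0), (20, 3), minval=0, maxval=10.0
)

def loss_fn(positions):
    result, _, _ = force_field.apply({"positions": positions}, {}, None)
    return result["energy"]

# Gradient of the energy w.r.t. positions (should equal -result["forces"])
grads = jax.grad(loss_fn)(positions)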
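Because the integrators are written in JAX, gradients can also flow through an entire rollout, not just a single energy evaluation. The self-contained sketch below is plain JAX, not the diffbio API, and `final_separation` is a hypothetical loss. It differentiates the separation of a Lennard-Jones dimer (started at rest at r = 1.05 sigma) after a short velocity Verlet run with respect to `epsilon`, then compares the result with a central finite difference. Double precision is enabled only to make that comparison tight.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # for an accurate finite-difference check


def final_separation(epsilon, r0=1.05, dt=0.005, n_steps=200):
    """Separation of a 1D LJ dimer (sigma = 1, unit masses) after velocity Verlet."""

    def energy(x):
        r = jnp.abs(x[1] - x[0])
        sr6 = (1.0 / r) ** 6
        return 4.0 * epsilon * (sr6**2 - sr6)

    def force(x):
        return -jax.grad(energy)(x)

    def step(carry, _):
        x, v = carry
        v = v + 0.5 * dt * force(x)
        x = x + dt * v
        v = v + 0.5 * dt * force(x)
        return (x, v), None

    x0 = jnp.array([0.0, r0])
    (x, _), _ = jax.lax.scan(step, (x0, jnp.zeros(2)), None, length=n_steps)
    return x[1] - x[0]


grad_ad = jax.grad(final_separation)(1.0)  # reverse-mode through all 200 steps
h = 1e-5
grad_fd = (final_separation(1.0 + h) - final_separation(1.0 - h)) / (2 * h)
print(grad_ad, grad_fd)
```

The same pattern (a scalar loss on the final state, differentiated with `jax.grad`) applies to other simulation parameters such as `sigma` or the initial positions.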

Input Specifications

ForceFieldOperator

  • positions: shape (n, dim) or (batch, n, dim), float32. Particle positions.

MDIntegratorOperator

  • positions: shape (n, dim), float32. Initial positions.
  • velocities: shape (n, dim), float32. Initial velocities.

Output Specifications

ForceFieldOperator

  • positions: same shape as input, float32. Original positions.
  • energy: shape () or (batch,), float32. Total potential energy.
  • forces: same shape as positions, float32. Force vectors.

MDIntegratorOperator

  • positions: shape (n, dim), float32. Final positions.
  • velocities: shape (n, dim), float32. Final velocities.
  • trajectory: shape (n_steps + 1, n, dim), float32. Position trajectory, including the initial frame.

Potential Parameters

  • Lennard-Jones (sigma, epsilon, cutoff): standard 12-6 potential.
  • Soft sphere (sigma, epsilon): purely repulsive.
  • Morse (sigma, epsilon, alpha): anharmonic bonded potential.