Molecular Dynamics Operators API
Differentiable operators for molecular dynamics simulations using JAX-MD.
ForceFieldOperator
diffbio.operators.molecular_dynamics.force_field.ForceFieldOperator
```python
ForceFieldOperator(
    config: ForceFieldConfig, *, rngs: Rngs | None = None
)
```
Bases: OperatorModule
Differentiable force field operator using JAX-MD.
Computes the potential energy and forces for a system of particles using classical pairwise potentials. Forces are obtained as the negative gradient of the energy via JAX automatic differentiation.
Supported potentials
- lennard_jones: Standard 12-6 LJ potential
- morse: Morse potential for bonded interactions
- soft_sphere: Soft repulsive potential
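Example
```python
import jax
from flax import nnx
from diffbio.operators.molecular_dynamics import ForceFieldConfig, ForceFieldOperator

config = ForceFieldConfig(potential_type="lennard_jones", box_size=10.0)
operator = ForceFieldOperator(config, rngs=nnx.Rngs(42))
positions = jax.random.uniform(jax.random.PRNGKey(0), (20, 3), maxval=10.0)
data = {"positions": positions}  # (n_particles, dim)
result, state, meta = operator.apply(data, {}, None)
energy = result["energy"]  # scalar
forces = result["forces"]  # (n_particles, dim)
```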
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ForceFieldConfig` | Force field configuration. | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. | `None` |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]
```
Compute energy and forces for particle positions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Input data containing `positions`: particle positions of shape `(n_particles, dim)` or `(batch, n_particles, dim)`. | required |
| `state` | `dict[str, Any]` | Per-element state (passed through). | required |
| `metadata` | `dict[str, Any] \| None` | Optional metadata. | required |
| `random_params` | `Any` | Unused random parameters. | `None` |
| `stats` | `dict[str, Any] \| None` | Optional statistics dictionary. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None]` | Tuple of: the input data with added `"energy"` and `"forces"` keys; the unchanged state; the unchanged metadata. |
ForceFieldConfig
diffbio.operators.molecular_dynamics.force_field.ForceFieldConfig
dataclass
```python
ForceFieldConfig(
    potential_type: str = "lennard_jones",
    sigma: float = 1.0,
    epsilon: float = 1.0,
    cutoff: float | None = 2.5,
    box_size: float | None = None,
    alpha: float = 5.0,
    geometric_loss_weight: float = 0.0,
)
```
Bases: OperatorConfig
Configuration for force field operator.
Attributes:
| Name | Type | Description |
|---|---|---|
| `potential_type` | `str` | Type of potential (`"lennard_jones"`, `"morse"`, `"soft_sphere"`). |
| `sigma` | `float` | Length scale parameter (particle diameter). |
| `epsilon` | `float` | Energy scale parameter (well depth). |
| `cutoff` | `float \| None` | Cutoff distance for interactions, in units of `sigma`. `None` for no cutoff. |
| `box_size` | `float \| None` | Size of the periodic box. `None` for non-periodic boundaries. |
| `alpha` | `float` | Morse potential width parameter (used only for `"morse"`). |
| `reference_positions` | `float` | Optional reference positions for computing geometric losses (chamfer/EMD) via artifex. If provided, … |
| `geometric_loss_weight` | `float` | Weight for geometric loss terms. |
MDIntegratorOperator
diffbio.operators.molecular_dynamics.integrator.MDIntegratorOperator
```python
MDIntegratorOperator(
    config: MDIntegratorConfig,
    *,
    rngs: Rngs | None = None,
    name: str | None = None,
)
```
Bases: OperatorModule
Differentiable MD integrator operator using JAX-MD.
Evolves particle positions and velocities over time using classical molecular dynamics integration schemes.
Supported integrators
- velocity_verlet: Symplectic velocity Verlet (NVE)
- nvt_langevin: Langevin dynamics for NVT ensemble
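Example
```python
import jax
from flax import nnx
from diffbio.operators.molecular_dynamics import MDIntegratorConfig, MDIntegratorOperator

config = MDIntegratorConfig(dt=0.001, n_steps=1000, box_size=10.0)
integrator = MDIntegratorOperator(config, rngs=nnx.Rngs(42))
key_pos, key_vel = jax.random.split(jax.random.PRNGKey(0))
positions = jax.random.uniform(key_pos, (20, 3), minval=2.0, maxval=8.0)
velocities = 0.1 * jax.random.normal(key_vel, (20, 3))
data = {"positions": positions, "velocities": velocities}
result, state, meta = integrator.apply(data, {}, None)
final_positions = result["positions"]
trajectory = result["trajectory"]
```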
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `MDIntegratorConfig` | Integrator configuration. | required |
| `rngs` | `Rngs \| None` | Flax NNX random number generators. | `None` |
| `name` | `str \| None` | Optional name for the operator. | `None` |
apply
```python
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]
```
Run MD simulation for specified number of steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Input data containing `positions` (initial particle positions, `(n_particles, dim)`) and `velocities` (initial particle velocities, `(n_particles, dim)`). | required |
| `state` | `dict[str, Any]` | Per-element state (passed through). | required |
| `metadata` | `dict[str, Any] \| None` | Optional metadata. | required |
| `random_params` | `Any` | Unused random parameters. | `None` |
| `stats` | `dict[str, Any] \| None` | Optional statistics dictionary. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None]` | Tuple of: the input data with updated `positions`/`velocities` and `trajectory`; the unchanged state; the unchanged metadata. |
MDIntegratorConfig
diffbio.operators.molecular_dynamics.integrator.MDIntegratorConfig
dataclass
```python
MDIntegratorConfig(
    integrator_type: str = "velocity_verlet",
    dt: float = 0.001,
    n_steps: int = 100,
    box_size: float | None = 10.0,
    potential_type: str = "lennard_jones",
    sigma: float = 1.0,
    epsilon: float = 1.0,
    mass: float = 1.0,
    kT: float = 1.0,
    gamma: float = 1.0,
)
```
Bases: OperatorConfig
Configuration for MD integrator operator.
Attributes:
| Name | Type | Description |
|---|---|---|
| `integrator_type` | `str` | Type of integrator (`"velocity_verlet"`, `"nvt_langevin"`). |
| `dt` | `float` | Time step for integration. |
| `n_steps` | `int` | Number of integration steps. |
| `box_size` | `float \| None` | Size of the periodic box. `None` for non-periodic boundaries. |
| `potential_type` | `str` | Type of potential (`"lennard_jones"`, `"morse"`, `"soft_sphere"`). |
| `sigma` | `float` | Sigma parameter for the potential (length scale). |
| `epsilon` | `float` | Epsilon parameter for the potential (energy scale). |
| `mass` | `float` | Particle mass (uniform for all particles). |
| `kT` | `float` | Thermal energy for the Langevin thermostat. |
| `gamma` | `float` | Friction coefficient for Langevin dynamics. |
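The `kT` and `gamma` attributes enter through Langevin dynamics: friction `gamma` damps the velocities while a matching random force (fluctuation-dissipation) holds the system at temperature `kT`. The sketch below uses the BAOAB splitting, which is one common discretization and not necessarily the scheme JAX-MD's `nvt_langevin` implements; `baoab_step` is a hypothetical helper, and uniform mass and free boundaries are assumed.

```python
import jax
import jax.numpy as jnp


def baoab_step(key, r, v, force_fn, dt=0.001, kT=1.0, gamma=1.0, mass=1.0):
    """One hypothetical BAOAB Langevin (NVT) step with free boundaries.

    Forces are recomputed at the start of each call for simplicity.
    """
    c1 = jnp.exp(-gamma * dt)  # fraction of velocity retained over one step
    c2 = jnp.sqrt((1.0 - c1**2) * kT / mass)  # noise amplitude (fluctuation-dissipation)
    v = v + 0.5 * dt * force_fn(r) / mass  # B: half kick
    r = r + 0.5 * dt * v  # A: half drift
    v = c1 * v + c2 * jax.random.normal(key, v.shape)  # O: friction + noise
    r = r + 0.5 * dt * v  # A: half drift
    v = v + 0.5 * dt * force_fn(r) / mass  # B: half kick
    return r, v


def zero_force(r):
    # Free particles: isolates the thermostat's effect on the velocities.
    return jnp.zeros_like(r)


key = jax.random.PRNGKey(0)
r = jnp.zeros((1000, 3))
v = jnp.zeros((1000, 3))
for _ in range(100):
    key, subkey = jax.random.split(key)
    r, v = baoab_step(subkey, r, v, zero_force, dt=0.01, kT=2.0, gamma=50.0)
```

With zero forces the O-step alone samples the Maxwell-Boltzmann distribution, so once the initial velocities are forgotten the mean of `v**2` per component should sit near `kT / mass` (here 2.0).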
PotentialType
diffbio.operators.molecular_dynamics.primitives.PotentialType
Bases: StrEnum
Enumeration of supported potential types.
Factory Functions
create_force_field
diffbio.operators.molecular_dynamics.force_field.create_force_field
```python
create_force_field(
    potential_type: str | PotentialType = PotentialType.LENNARD_JONES,
    sigma: float = 1.0,
    epsilon: float = 1.0,
    cutoff: float | None = 2.5,
    box_size: float | None = None,
    alpha: float = 5.0,
    seed: int = 42,
) -> ForceFieldOperator
```
Create a force field operator with specified potential.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `potential_type` | `str \| PotentialType` | Type of potential (`"lennard_jones"`, `"morse"`, `"soft_sphere"`) or a `PotentialType` enum member. | `PotentialType.LENNARD_JONES` |
| `sigma` | `float` | Particle diameter (length scale). | `1.0` |
| `epsilon` | `float` | Well depth (energy scale). | `1.0` |
| `cutoff` | `float \| None` | Cutoff distance in units of `sigma`. `None` for no cutoff. | `2.5` |
| `box_size` | `float \| None` | Periodic box size. `None` for non-periodic boundaries. | `None` |
| `alpha` | `float` | Morse potential width parameter. | `5.0` |
| `seed` | `int` | Random seed for initialization. | `42` |
Returns:
| Type | Description |
|---|---|
| `ForceFieldOperator` | Configured `ForceFieldOperator`. |
create_lennard_jones_operator
diffbio.operators.molecular_dynamics.force_field.create_lennard_jones_operator
```python
create_lennard_jones_operator(
    sigma: float = 1.0,
    epsilon: float = 1.0,
    cutoff: float | None = 2.5,
    box_size: float | None = None,
    seed: int = 42,
) -> ForceFieldOperator
```
Create a Lennard-Jones force field operator.
This is a convenience function for creating a force field operator with a Lennard-Jones potential.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sigma` | `float` | Particle diameter (length scale). | `1.0` |
| `epsilon` | `float` | Well depth (energy scale). | `1.0` |
| `cutoff` | `float \| None` | Cutoff distance in units of `sigma`. `None` for no cutoff. | `2.5` |
| `box_size` | `float \| None` | Periodic box size. `None` for non-periodic boundaries. | `None` |
| `seed` | `int` | Random seed for initialization. | `42` |
Returns:
| Type | Description |
|---|---|
| `ForceFieldOperator` | Configured `ForceFieldOperator`. |
create_integrator
diffbio.operators.molecular_dynamics.integrator.create_integrator
```python
create_integrator(
    integrator_type: str = "velocity_verlet",
    dt: float = 0.001,
    n_steps: int = 100,
    box_size: float | None = 10.0,
    potential_type: str | PotentialType = PotentialType.LENNARD_JONES,
    sigma: float = 1.0,
    epsilon: float = 1.0,
    mass: float = 1.0,
    kT: float = 1.0,
    gamma: float = 1.0,
    seed: int = 42,
) -> MDIntegratorOperator
```
Create an MD integrator operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `integrator_type` | `str` | Type of integrator (`"velocity_verlet"`, `"nvt_langevin"`). | `"velocity_verlet"` |
| `dt` | `float` | Time step for integration. | `0.001` |
| `n_steps` | `int` | Number of integration steps. | `100` |
| `box_size` | `float \| None` | Periodic box size. `None` for non-periodic boundaries. | `10.0` |
| `potential_type` | `str \| PotentialType` | Type of potential (`"lennard_jones"`, `"morse"`, `"soft_sphere"`) or a `PotentialType` enum member. | `PotentialType.LENNARD_JONES` |
| `sigma` | `float` | Sigma parameter for the potential (length scale). | `1.0` |
| `epsilon` | `float` | Epsilon parameter for the potential (energy scale). | `1.0` |
| `mass` | `float` | Particle mass. | `1.0` |
| `kT` | `float` | Thermal energy for the Langevin thermostat. | `1.0` |
| `gamma` | `float` | Friction coefficient for Langevin dynamics. | `1.0` |
| `seed` | `int` | Random seed for initialization. | `42` |
Returns:
| Type | Description |
|---|---|
| `MDIntegratorOperator` | Configured `MDIntegratorOperator`. |
create_verlet_integrator
diffbio.operators.molecular_dynamics.integrator.create_verlet_integrator
```python
create_verlet_integrator(
    dt: float = 0.001,
    n_steps: int = 100,
    box_size: float | None = 10.0,
    sigma: float = 1.0,
    epsilon: float = 1.0,
    seed: int = 42,
) -> MDIntegratorOperator
```
Create a velocity Verlet integrator operator.
This is a convenience function for creating an NVE integrator that uses the velocity Verlet algorithm.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dt` | `float` | Time step for integration. | `0.001` |
| `n_steps` | `int` | Number of integration steps. | `100` |
| `box_size` | `float \| None` | Periodic box size. `None` for non-periodic boundaries. | `10.0` |
| `sigma` | `float` | Sigma parameter for the potential. | `1.0` |
| `epsilon` | `float` | Epsilon parameter for the potential. | `1.0` |
| `seed` | `int` | Random seed for initialization. | `42` |
Returns:
| Type | Description |
|---|---|
| `MDIntegratorOperator` | Configured `MDIntegratorOperator`. |
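Velocity Verlet advances positions and velocities with a half-kick, drift, half-kick split; the scheme is time-reversible and symplectic, which keeps long-run energy drift small in NVE runs. The sketch below is a pure-JAX illustration, assuming free boundaries, a uniform mass, and a user-supplied force function; `verlet_rollout` is a hypothetical name, not part of diffbio. It also shows how a `jax.lax.scan` rollout that keeps the initial frame yields `n_steps + 1` snapshots.

```python
import jax
import jax.numpy as jnp


def verlet_rollout(force_fn, positions, velocities, dt=0.001, n_steps=100, mass=1.0):
    """Hypothetical NVE velocity Verlet rollout with free boundaries."""

    def step(carry, _):
        r, v, f = carry
        v_half = v + 0.5 * dt * f / mass  # first half kick
        r_new = r + dt * v_half  # full drift
        f_new = force_fn(r_new)  # forces at the new positions
        v_new = v_half + 0.5 * dt * f_new / mass  # second half kick
        return (r_new, v_new, f_new), r_new

    init = (positions, velocities, force_fn(positions))
    (r_final, v_final, _), frames = jax.lax.scan(step, init, None, length=n_steps)
    # Prepend the starting frame so the trajectory holds n_steps + 1 snapshots.
    trajectory = jnp.concatenate([positions[None], frames], axis=0)
    return r_final, v_final, trajectory


def harmonic_force(r):
    # Isotropic harmonic trap with unit spring constant: F = -r.
    return -r


r0 = jnp.ones((4, 3))
v0 = jnp.zeros((4, 3))
r_T, v_T, traj = verlet_rollout(harmonic_force, r0, v0, dt=0.01, n_steps=1000)
# traj.shape == (1001, 4, 3); total energy stays close to its initial value of 6.0
```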
Primitive Functions
create_displacement_fn
diffbio.operators.molecular_dynamics.primitives.create_displacement_fn
Create displacement and shift functions based on boundary conditions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `box_size` | `float \| None` | Size of the periodic box. `None` for non-periodic (free) boundaries. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Callable, Callable]` | Tuple `(displacement_fn, shift_fn)`, where `displacement_fn` computes the displacement vector between two points and `shift_fn` applies a displacement to a position while respecting the boundary conditions. |
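For intuition, periodic displacement usually follows the minimum-image convention, while free boundaries use the plain difference. The sketch below is a pure `jax.numpy` illustration assuming a cubic box of side `box_size` and the `ra - rb` sign convention; it is not the diffbio implementation (which builds on JAX-MD), and `make_displacement_and_shift` is a hypothetical helper.

```python
import jax.numpy as jnp


def make_displacement_and_shift(box_size=None):
    """Hypothetical helper mirroring the create_displacement_fn contract."""
    if box_size is None:
        # Free boundaries: plain vector difference and plain addition.
        def displacement_fn(ra, rb):
            return ra - rb

        def shift_fn(r, dr):
            return r + dr
    else:
        # Periodic cubic box: minimum-image displacement, wrapped shift.
        def displacement_fn(ra, rb):
            dr = ra - rb
            return dr - box_size * jnp.round(dr / box_size)

        def shift_fn(r, dr):
            return jnp.mod(r + dr, box_size)

    return displacement_fn, shift_fn


disp, shift = make_displacement_and_shift(box_size=10.0)
a = jnp.array([9.5, 0.0, 0.0])
b = jnp.array([0.5, 0.0, 0.0])
d = disp(a, b)  # → [-1., 0., 0.] (minimum image), not [9., 0., 0.]
s = shift(a, jnp.array([1.0, 0.0, 0.0]))  # → [0.5, 0., 0.] after wrapping
```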
create_energy_fn
diffbio.operators.molecular_dynamics.primitives.create_energy_fn
```python
create_energy_fn(
    displacement_fn: Callable,
    potential_type: PotentialType | str = PotentialType.LENNARD_JONES,
    sigma: float = 1.0,
    epsilon: float = 1.0,
    cutoff: float | None = None,
    alpha: float = 5.0,
) -> Callable
```
Create energy function for the specified potential.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `displacement_fn` | `Callable` | Displacement function from `create_displacement_fn`. | required |
| `potential_type` | `PotentialType \| str` | Type of potential to use. | `PotentialType.LENNARD_JONES` |
| `sigma` | `float` | Length scale parameter (particle diameter). | `1.0` |
| `epsilon` | `float` | Energy scale parameter (well depth). | `1.0` |
| `cutoff` | `float \| None` | Cutoff distance for interactions. `None` for no cutoff. | `None` |
| `alpha` | `float` | Morse potential width parameter (used only for `"morse"`). | `5.0` |
Returns:
| Type | Description |
|---|---|
| `Callable` | Energy function that takes positions and returns the total energy. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `potential_type` is not recognized. |
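As a mental model for what the returned energy function computes, the sketch below sums a Lennard-Jones pair energy over all unique particle pairs, taking a displacement function like the one returned by `create_displacement_fn`. Assumptions: the cutoff is expressed in units of `sigma` and applied as a hard truncation, and pairs are enumerated densely in O(N²). JAX-MD's pair potentials may instead smooth the energy to zero near the cutoff and use neighbor lists, and `make_lj_energy_fn` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def make_lj_energy_fn(displacement_fn, sigma=1.0, epsilon=1.0, cutoff=2.5):
    """Hypothetical dense LJ total energy; cutoff in units of sigma (assumption)."""
    r_cut = cutoff * sigma

    def energy_fn(positions):
        # Unique pairs i < j, so no self-interaction and no double counting.
        i, j = jnp.triu_indices(positions.shape[0], k=1)
        dR = jax.vmap(displacement_fn)(positions[i], positions[j])
        r = jnp.linalg.norm(dR, axis=-1)
        sr6 = (sigma / r) ** 6
        pair_energy = 4.0 * epsilon * (sr6**2 - sr6)
        # Hard truncation: pairs at or beyond r_cut contribute nothing.
        return jnp.sum(jnp.where(r < r_cut, pair_energy, 0.0))

    return energy_fn


# Two particles at the LJ minimum r = 2**(1/6) * sigma give E = -epsilon.
energy_fn = make_lj_energy_fn(lambda ra, rb: ra - rb)
pos = jnp.array([[0.0, 0.0, 0.0], [2.0 ** (1.0 / 6.0), 0.0, 0.0]])
e_min = energy_fn(pos)  # ≈ -1.0
```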
create_force_fn
diffbio.operators.molecular_dynamics.primitives.create_force_fn
Create force function from energy function.
Forces are computed as the negative gradient of the energy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `energy_fn` | `Callable` | Energy function that takes positions and returns energy. | required |
Returns:
| Type | Description |
|---|---|
| `Callable` | Force function that takes positions and returns forces. |
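The relationship F = -∇E can be reproduced with `jax.grad` on any scalar energy function of the positions. The sketch below uses a toy harmonic bond (chosen because its forces are easy to check by hand) rather than one of the library's potentials; `make_force_fn` is a hypothetical stand-in for `create_force_fn`.

```python
import jax
import jax.numpy as jnp


def make_force_fn(energy_fn):
    """Hypothetical equivalent of create_force_fn: F = -dE/dR."""
    grad_fn = jax.grad(energy_fn)

    def force_fn(positions):
        return -grad_fn(positions)

    return force_fn


def spring_energy(positions, k=2.0, r0=1.0):
    # Toy harmonic bond between particle 0 and particle 1.
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


force_fn = make_force_fn(spring_energy)
pos = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
forces = force_fn(pos)  # → [[1, 0, 0], [-1, 0, 0]]: the stretched bond pulls the pair together
```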
Usage Examples
Force Field Computation
```python
from diffbio.operators.molecular_dynamics import (
    create_force_field,
    PotentialType,
)
import jax
import jax.numpy as jnp

# Create Lennard-Jones force field
force_field = create_force_field(
    potential_type=PotentialType.LENNARD_JONES,
    sigma=1.0,
    epsilon=1.0,
    box_size=10.0,
)
# Generate positions
positions = jax.random.uniform(
    jax.random.PRNGKey(0), (20, 3), minval=0, maxval=10.0
)
# Compute energy and forces
result, _, _ = force_field.apply({"positions": positions}, {}, None)
energy = result["energy"]  # scalar
forces = result["forces"]  # (20, 3)
```
MD Simulation
```python
from diffbio.operators.molecular_dynamics import create_verlet_integrator
import jax

# Create integrator
integrator = create_verlet_integrator(
    dt=0.001,
    n_steps=1000,
    box_size=10.0,
)
# Initial conditions
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
positions = jax.random.uniform(key1, (20, 3), minval=2, maxval=8.0)
velocities = jax.random.normal(key2, (20, 3)) * 0.1
# Run simulation
result, _, _ = integrator.apply(
    {"positions": positions, "velocities": velocities}, {}, None
)
trajectory = result["trajectory"]  # (1001, 20, 3)
```
Full Configuration
```python
from diffbio.operators.molecular_dynamics import (
    ForceFieldOperator,
    ForceFieldConfig,
    MDIntegratorOperator,
    MDIntegratorConfig,
)
from flax import nnx

# Force field with custom config
ff_config = ForceFieldConfig(
    potential_type="morse",
    sigma=1.0,
    epsilon=2.0,
    alpha=5.0,
    box_size=15.0,
)
force_field = ForceFieldOperator(ff_config, rngs=nnx.Rngs(42))
# Integrator with custom config
int_config = MDIntegratorConfig(
    integrator_type="nvt_langevin",
    dt=0.002,
    n_steps=500,
    box_size=15.0,
    kT=1.0,
    gamma=0.5,
)
integrator = MDIntegratorOperator(int_config, rngs=nnx.Rngs(42))
```
Batched Processing
```python
import jax
from diffbio.operators.molecular_dynamics import create_force_field

force_field = create_force_field(box_size=10.0)
# Batch of configurations
batch_size = 8
n_particles = 20
dim = 3
positions = jax.random.uniform(
    jax.random.PRNGKey(0),
    (batch_size, n_particles, dim),
    minval=0,
    maxval=10.0,
)
# Force field handles batched input
result, _, _ = force_field.apply({"positions": positions}, {}, None)
energies = result["energy"]  # (8,)
forces = result["forces"]  # (8, 20, 3)
```
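One common way to accept both `(n, dim)` and `(batch, n, dim)` inputs is to write the energy for a single configuration and lift it with `jax.vmap` when a leading batch axis is present. This is a sketch of the idea, not necessarily how `ForceFieldOperator` dispatches internally; `toy_energy` and `batched_energy_and_forces` are hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_energy(positions):
    # Hypothetical single-configuration energy: harmonic trap around the origin.
    return 0.5 * jnp.sum(positions**2)


energy_and_grad = jax.value_and_grad(toy_energy)


def batched_energy_and_forces(positions):
    """Dispatch on rank: (n, dim) -> scalar energy, (batch, n, dim) -> (batch,)."""
    if positions.ndim == 3:
        energy, grad = jax.vmap(energy_and_grad)(positions)
    else:
        energy, grad = energy_and_grad(positions)
    return energy, -grad  # forces are the negative gradient


pos = jax.random.uniform(jax.random.PRNGKey(0), (8, 20, 3), maxval=10.0)
energies, forces = batched_energy_and_forces(pos)
# energies.shape == (8,), forces.shape == (8, 20, 3)
```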
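Gradient Computation
```python
import jax
from diffbio.operators.molecular_dynamics import create_force_field

force_field = create_force_field(box_size=10.0)
positions = jax.random.uniform(
    jax.random.PRNGKey(0), (20, 3), minval=0, maxval=10.0
)

def loss_fn(positions):
    result, _, _ = force_field.apply({"positions": positions}, {}, None)
    return result["energy"]

# Compute gradients w.r.t. positions (these equal -forces)
grads = jax.grad(loss_fn)(positions)
```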
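Because an energy written in JAX is an ordinary differentiable function, the same pattern extends to force-field parameters when they are passed as arguments. The sketch below is a standalone illustration using a Lennard-Jones dimer written directly in `jax.numpy` (`dimer_energy` is hypothetical); it does not imply that `ForceFieldConfig` fields are differentiable through `apply`.

```python
import jax


def dimer_energy(sigma, epsilon, r):
    # Lennard-Jones energy of a single pair at separation r.
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6**2 - sr6)


# Gradients w.r.t. (sigma, epsilon) at sigma = epsilon = r = 1.0.
# Analytically: dE/dsigma = 4*eps*(12 - 6) = 24, dE/depsilon = 4*(1 - 1) = 0.
dE_dsigma, dE_depsilon = jax.grad(dimer_energy, argnums=(0, 1))(1.0, 1.0, 1.0)
```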
Input Specifications
ForceFieldOperator
| Key | Shape | Type | Description |
|---|---|---|---|
| `positions` | `(n, dim)` or `(batch, n, dim)` | float32 | Particle positions |
MDIntegratorOperator
| Key | Shape | Type | Description |
|---|---|---|---|
| `positions` | `(n, dim)` | float32 | Initial positions |
| `velocities` | `(n, dim)` | float32 | Initial velocities |
Output Specifications
ForceFieldOperator
| Key | Shape | Type | Description |
|---|---|---|---|
| `positions` | same as input | float32 | Original positions |
| `energy` | `()` or `(batch,)` | float32 | Total potential energy |
| `forces` | same as `positions` | float32 | Force vectors |
MDIntegratorOperator
| Key | Shape | Type | Description |
|---|---|---|---|
| `positions` | `(n, dim)` | float32 | Final positions |
| `velocities` | `(n, dim)` | float32 | Final velocities |
| `trajectory` | `(n_steps + 1, n, dim)` | float32 | Position trajectory, including the initial frame |
Potential Parameters
| Potential | Parameters | Description |
|---|---|---|
| Lennard-Jones | sigma, epsilon, cutoff | Standard 12-6 potential |
| Soft Sphere | sigma, epsilon | Purely repulsive |
| Morse | sigma, epsilon, alpha | Anharmonic bonded potential |
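For reference, the three pair potentials in their commonly used forms, written as functions of the pair distance `r`. The Morse and soft-sphere expressions follow JAX-MD's conventions as we understand them (Morse minimum of `-epsilon` at `r = sigma`; soft sphere with exponent 2 and zero energy for `r >= sigma`); treat the exact forms, and the soft-sphere `exponent` parameter in particular, as assumptions to check against the library.

```python
import jax.numpy as jnp


def lennard_jones(r, sigma=1.0, epsilon=1.0):
    # 12-6 potential: zero at r = sigma, minimum of -epsilon at r = 2**(1/6) * sigma.
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6**2 - sr6)


def morse(r, sigma=1.0, epsilon=1.0, alpha=5.0):
    # Anharmonic well: minimum of -epsilon at r = sigma, decays to 0 as r grows.
    return epsilon * (1.0 - jnp.exp(-alpha * (r - sigma))) ** 2 - epsilon


def soft_sphere(r, sigma=1.0, epsilon=1.0, exponent=2.0):
    # Purely repulsive overlap penalty; exactly zero for r >= sigma (exponent assumed).
    overlap = jnp.maximum(1.0 - r / sigma, 0.0)
    return epsilon / exponent * overlap**exponent


# All three are elementwise, so they evaluate over an array of distances at once.
lj_curve = lennard_jones(jnp.linspace(0.9, 2.5, 50))
```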