Molecular Dynamics Operators¤
DiffBio provides differentiable operators for molecular dynamics simulations, wrapping JAX-MD for seamless integration with DiffBio pipelines.
Overview¤
Molecular dynamics operators enable gradient-based optimization of molecular simulations:
- ForceFieldOperator: Compute energies and forces from particle positions
- MDIntegratorOperator: Time integration for MD simulations (NVE, NVT)
Architecture¤
graph TB
A["Particle Positions<br/>(n_particles, dim)"] --> B["ForceFieldOperator<br/>Lennard-Jones / Morse /<br/>Soft Sphere Potential"]
B --> E1["Energy (scalar)"]
B --> F1["Forces (n_particles, dim)<br/>F = -∇E (auto-differentiation)"]
B --> C["MDIntegratorOperator<br/>Velocity Verlet (NVE) or<br/>Langevin Dynamics (NVT)"]
C --> D1["Final Positions"]
C --> D2["Final Velocities"]
C --> D3["Trajectory<br/>(n_steps+1, n_particles, dim)"]
style A fill:#d1fae5,stroke:#059669,color:#064e3b
style B fill:#e0e7ff,stroke:#4338ca,color:#312e81
style E1 fill:#d1fae5,stroke:#059669,color:#064e3b
style F1 fill:#d1fae5,stroke:#059669,color:#064e3b
style C fill:#e0e7ff,stroke:#4338ca,color:#312e81
style D1 fill:#d1fae5,stroke:#059669,color:#064e3b
style D2 fill:#d1fae5,stroke:#059669,color:#064e3b
style D3 fill:#d1fae5,stroke:#059669,color:#064e3b
ForceFieldOperator¤
Computes potential energy and forces for a system of particles using classical pairwise potentials.
Quick Start¤
from flax import nnx
import jax
import jax.numpy as jnp

from diffbio.operators.molecular_dynamics import (
    ForceFieldOperator,
    ForceFieldConfig,
    PotentialType,
    create_force_field,
)

# Create force field operator
force_field = create_force_field(
    potential_type=PotentialType.LENNARD_JONES,
    sigma=1.0,      # Particle diameter
    epsilon=1.0,    # Well depth
    box_size=10.0,  # Periodic box size
)

# Generate random particle positions
n_particles = 20
dim = 3
key = jax.random.PRNGKey(0)
positions = jax.random.uniform(key, (n_particles, dim), minval=0, maxval=10.0)

# Compute energy and forces
data = {"positions": positions}
result, state, metadata = force_field.apply(data, {}, None)
energy = result["energy"]  # Scalar total energy
forces = result["forces"]  # (n_particles, dim) force vectors
Configuration¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| potential_type | str | "lennard_jones" | Potential type: "lennard_jones", "morse", "soft_sphere" |
| sigma | float | 1.0 | Length scale parameter (particle diameter) |
| epsilon | float | 1.0 | Energy scale parameter (well depth) |
| cutoff | float or None | 2.5 | Cutoff distance (units of sigma). None for no cutoff |
| box_size | float or None | None | Periodic box size. None for free boundaries |
| alpha | float | 5.0 | Morse potential width parameter |
Supported Potentials¤
Lennard-Jones (12-6)¤
Standard potential for van der Waals interactions:
lj_field = create_force_field(
    potential_type=PotentialType.LENNARD_JONES,
    sigma=1.0,
    epsilon=1.0,
    cutoff=2.5,
)
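To make the relation F = -∇E concrete, here is a minimal pure-JAX sketch of a 12-6 Lennard-Jones energy with forces obtained by automatic differentiation. It is an illustration only and not DiffBio's implementation: the helper names `lj_total_energy` and `lj_forces` are hypothetical, and the sketch assumes free boundaries and no cutoff.

```python
import jax
import jax.numpy as jnp


def lj_total_energy(positions, sigma=1.0, epsilon=1.0):
    """Sum of 12-6 Lennard-Jones pair energies over unique pairs.

    Assumes free boundaries and no cutoff (illustrative simplification).
    """
    n = positions.shape[0]
    i, j = jnp.triu_indices(n, k=1)  # unique pairs with i < j
    r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    sr6 = (sigma / r) ** 6
    return jnp.sum(4.0 * epsilon * (sr6 ** 2 - sr6))


def lj_forces(positions, **kwargs):
    """Forces as the negative gradient of the total energy."""
    return -jax.grad(lj_total_energy)(positions, **kwargs)


# A dimer at the potential minimum r = 2^(1/6) * sigma: energy -epsilon, zero force.
r_min = 2.0 ** (1.0 / 6.0)
dimer = jnp.array([[0.0, 0.0, 0.0], [r_min, 0.0, 0.0]])
dimer_energy = lj_total_energy(dimer)
dimer_forces = lj_forces(dimer)

# A compressed dimer at r = sigma: the particles repel each other.
squeezed = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
squeezed_forces = lj_forces(squeezed)
```

In the compressed dimer the two force vectors are equal and opposite, so the total force on the pair is zero.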
Soft Sphere¤
Purely repulsive potential:
soft_field = create_force_field(
    potential_type=PotentialType.SOFT_SPHERE,
    sigma=1.0,
    epsilon=1.0,
)
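The exact functional form behind the soft-sphere option is not stated here. As a hedged illustration, the sketch below assumes one common finite-range definition, U(r) = (ε/n)(1 - r/σ)^n for r < σ and zero otherwise, with an exponent n = 2. The names `soft_sphere_pair`, `pair_force`, and `exponent` are hypothetical.

```python
import jax
import jax.numpy as jnp


def soft_sphere_pair(r, sigma=1.0, epsilon=1.0, exponent=2.0):
    """Assumed soft-sphere pair energy: repulsive for r < sigma, zero beyond."""
    overlap = jnp.maximum(1.0 - r / sigma, 0.0)
    return epsilon / exponent * overlap ** exponent


# Scalar force along the separation, -dU/dr, via automatic differentiation.
pair_force = jax.grad(lambda r: -soft_sphere_pair(r))

overlapping_energy = soft_sphere_pair(0.5)  # particles overlap
overlapping_force = pair_force(0.5)         # positive: pushes the pair apart
separated_energy = soft_sphere_pair(1.5)    # beyond sigma: no interaction
separated_force = pair_force(1.5)
```

Because the energy vanishes identically beyond σ, non-overlapping particles feel no force at all.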
Morse Potential¤
For bonded interactions with anharmonicity:
morse_field = create_force_field(
    potential_type=PotentialType.MORSE,
    sigma=1.0,    # Equilibrium distance
    epsilon=1.0,  # Well depth
    alpha=5.0,    # Width parameter
)
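The roles of the three Morse parameters can be checked numerically. The sketch below assumes the standard Morse form shifted so that the well bottom sits at -ε, U(r) = ε(1 - exp(-α(r - σ)))² - ε. Whether DiffBio applies the same constant shift is an assumption, and `morse_pair` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def morse_pair(r, sigma=1.0, epsilon=1.0, alpha=5.0):
    """Assumed Morse pair energy with minimum -epsilon at r = sigma."""
    return epsilon * (1.0 - jnp.exp(-alpha * (r - sigma))) ** 2 - epsilon


well_bottom = morse_pair(1.0)                                # energy at r = sigma
slope_at_minimum = jax.grad(morse_pair)(1.0)                 # zero at equilibrium
curvature_at_minimum = jax.grad(jax.grad(morse_pair))(1.0)   # 2 * epsilon * alpha**2
dissociation_limit = morse_pair(10.0)                        # tends to 0 at large r
```

The curvature at the minimum, 2εα², shows why alpha is described as a width parameter: a larger alpha gives a narrower, stiffer well.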
Input/Output Formats¤
Input
| Key | Shape | Description |
|---|---|---|
| positions | (n_particles, dim) or (batch, n_particles, dim) | Particle positions |
Output
| Key | Shape | Description |
|---|---|---|
| positions | same as input | Original positions |
| energy | () or (batch,) | Total potential energy |
| forces | same as positions | Force vectors (F = -∇E) |
MDIntegratorOperator¤
Evolves particle positions and velocities over time using molecular dynamics integration.
Quick Start¤
from diffbio.operators.molecular_dynamics import (
    MDIntegratorOperator,
    MDIntegratorConfig,
    create_integrator,
    create_verlet_integrator,
)

# Create velocity Verlet integrator
integrator = create_verlet_integrator(
    dt=0.001,       # Time step
    n_steps=1000,   # Number of steps
    box_size=10.0,  # Periodic box size
    sigma=1.0,      # LJ sigma
    epsilon=1.0,    # LJ epsilon
)

# Initial conditions
n_particles = 20
dim = 3
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
positions = jax.random.uniform(key1, (n_particles, dim), minval=2, maxval=8.0)
velocities = jax.random.normal(key2, (n_particles, dim)) * 0.1

# Run simulation
data = {"positions": positions, "velocities": velocities}
result, state, metadata = integrator.apply(data, {}, None)
final_positions = result["positions"]    # Final positions
final_velocities = result["velocities"]  # Final velocities
trajectory = result["trajectory"]        # (n_steps+1, n_particles, dim)
Configuration¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| integrator_type | str | "velocity_verlet" | Integrator: "velocity_verlet" (NVE), "nvt_langevin" (NVT) |
| dt | float | 0.001 | Time step |
| n_steps | int | 100 | Number of integration steps |
| box_size | float or None | 10.0 | Periodic box size |
| potential_type | str | "lennard_jones" | Potential type |
| sigma | float | 1.0 | Potential length scale |
| epsilon | float | 1.0 | Potential energy scale |
| mass | float | 1.0 | Particle mass (uniform) |
| kT | float | 1.0 | Thermal energy (for Langevin) |
| gamma | float | 1.0 | Friction coefficient (for Langevin) |
Integrator Types¤
Velocity Verlet (NVE)¤
Symplectic integrator for the microcanonical ensemble (constant energy). The Quick Start example above creates this integrator with create_verlet_integrator.
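For readers who want to see what a velocity Verlet step does, the following is a minimal pure-JAX sketch on a toy harmonic potential. It is not the DiffBio or JAX-MD implementation: the names `velocity_verlet`, `harmonic_energy`, and `total_energy` are hypothetical, and the toy potential replaces the pair potentials used by the operator.

```python
import jax
import jax.numpy as jnp


def harmonic_energy(positions, k=1.0):
    """Toy potential: each coordinate is tethered to the origin."""
    return 0.5 * k * jnp.sum(positions ** 2)


def velocity_verlet(energy_fn, positions, velocities, dt, n_steps, mass=1.0):
    """Integrate with velocity Verlet and return final state plus trajectory."""
    force_fn = lambda x: -jax.grad(energy_fn)(x)

    def step(carry, _):
        x, v, f = carry
        v_half = v + 0.5 * dt * f / mass           # half kick with old force
        x_new = x + dt * v_half                    # drift
        f_new = force_fn(x_new)                    # force at new positions
        v_new = v_half + 0.5 * dt * f_new / mass   # half kick with new force
        return (x_new, v_new, f_new), x_new

    init = (positions, velocities, force_fn(positions))
    (x, v, _), frames = jax.lax.scan(step, init, None, length=n_steps)
    # Prepend the initial frame so the trajectory has n_steps + 1 entries.
    trajectory = jnp.concatenate([positions[None], frames], axis=0)
    return x, v, trajectory


def total_energy(x, v, mass=1.0):
    return harmonic_energy(x) + 0.5 * mass * jnp.sum(v ** 2)


x0 = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.5, 0.0]])
v0 = jnp.zeros_like(x0)
x_final, v_final, trajectory = velocity_verlet(
    harmonic_energy, x0, v0, dt=0.01, n_steps=1000
)
energy_drift = total_energy(x_final, v_final) - total_energy(x0, v0)
```

The small value of `energy_drift` reflects the near-conservation of energy expected from a symplectic scheme at a small time step.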
Langevin Dynamics (NVT)¤
Stochastic integrator for canonical ensemble (constant temperature):
nvt_integrator = create_integrator(
    integrator_type="nvt_langevin",
    dt=0.001,
    n_steps=1000,
    kT=1.0,     # Temperature
    gamma=1.0,  # Friction
)
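To illustrate how kT and gamma enter a Langevin integrator, here is a hedged pure-JAX sketch that uses the BAOAB splitting on independent harmonic oscillators. BAOAB is one common discretization and is an assumption here, not a statement about the scheme DiffBio uses. The names `langevin_baoab` and `harmonic_force` are hypothetical.

```python
import jax
import jax.numpy as jnp


def harmonic_force(x, k=1.0):
    """Toy restoring force standing in for the pair-potential forces."""
    return -k * x


def langevin_baoab(key, x, v, dt, n_steps, kT=1.0, gamma=1.0, mass=1.0):
    """Langevin dynamics with the BAOAB splitting (assumed scheme)."""
    c = jnp.exp(-gamma * dt)                            # friction damping per step
    noise_scale = jnp.sqrt((1.0 - c ** 2) * kT / mass)  # matching thermal noise

    def step(carry, step_key):
        x, v = carry
        v = v + 0.5 * dt * harmonic_force(x) / mass     # B: half kick
        x = x + 0.5 * dt * v                            # A: half drift
        noise = jax.random.normal(step_key, v.shape)
        v = c * v + noise_scale * noise                 # O: friction + noise
        x = x + 0.5 * dt * v                            # A: half drift
        v = v + 0.5 * dt * harmonic_force(x) / mass     # B: half kick
        return (x, v), None

    keys = jax.random.split(key, n_steps)               # one key per step
    (x, v), _ = jax.lax.scan(step, (x, v), keys)
    return x, v


key = jax.random.PRNGKey(0)
x0 = jnp.zeros((2000, 3))
v0 = jnp.zeros((2000, 3))
x_eq, v_eq = langevin_baoab(key, x0, v0, dt=0.01, n_steps=2000, kT=1.0, gamma=1.0)

# With mass = 1, the mean squared velocity per degree of freedom estimates kT.
kinetic_temperature = jnp.mean(v_eq ** 2)
```

Starting from rest, the thermostat heats the system until `kinetic_temperature` fluctuates around the requested kT.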
Input/Output Formats¤
Input
| Key | Shape | Description |
|---|---|---|
| positions | (n_particles, dim) | Initial particle positions |
| velocities | (n_particles, dim) | Initial particle velocities |
Output
| Key | Shape | Description |
|---|---|---|
| positions | (n_particles, dim) | Final particle positions |
| velocities | (n_particles, dim) | Final particle velocities |
| trajectory | (n_steps+1, n_particles, dim) | Position trajectory |
Gradient-Based Optimization¤
Optimizing Initial Conditions¤
import jax
import jax.numpy as jnp
import optax
from diffbio.operators.molecular_dynamics import create_verlet_integrator

integrator = create_verlet_integrator(dt=0.001, n_steps=100, box_size=10.0)

def loss_fn(positions, velocities):
    """Mean squared distance of the final positions from a target configuration."""
    result, _, _ = integrator.apply(
        {"positions": positions, "velocities": velocities},
        {}, None
    )
    # Example: minimize distance from target configuration
    target = jnp.zeros_like(result["positions"])
    return jnp.mean((result["positions"] - target) ** 2)

# Compute gradients with respect to the initial positions
# (positions and velocities as defined in the integrator Quick Start)
grads = jax.grad(loss_fn, argnums=0)(positions, velocities)
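The same idea can be tried without DiffBio installed. The sketch below differentiates through a small scan-based velocity Verlet loop and uses plain gradient descent to adjust the initial positions until the final positions reach a target. Everything in it (`simulate`, the harmonic toy force, the step size of 1.0, the target value 0.3) is a hypothetical stand-in chosen for illustration.

```python
import jax
import jax.numpy as jnp


def simulate(x0, v0, dt=0.01, n_steps=100, k=1.0):
    """Velocity Verlet for independent unit-mass harmonic oscillators."""
    force = lambda x: -k * x

    def step(carry, _):
        x, v = carry
        v = v + 0.5 * dt * force(x)
        x = x + dt * v
        v = v + 0.5 * dt * force(x)
        return (x, v), None

    (x, _), _ = jax.lax.scan(step, (x0, v0), None, length=n_steps)
    return x


def loss_fn(x0, v0, target):
    """Mean squared distance of the final positions from the target."""
    return jnp.mean((simulate(x0, v0) - target) ** 2)


# Gradient with respect to the initial positions, through all 100 steps.
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0))

x_init = jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
v0 = jnp.zeros_like(x_init)
target = jnp.full_like(x_init, 0.3)

initial_grad = grad_fn(x_init, v0, target)
initial_loss = loss_fn(x_init, v0, target)

x_opt = x_init
for _ in range(200):
    x_opt = x_opt - 1.0 * grad_fn(x_opt, v0, target)  # plain gradient descent
final_loss = loss_fn(x_opt, v0, target)
```

For this toy system the final position is approximately cos(t) times the initial one, which gives a closed-form gradient to compare against when checking that backpropagation through the loop is correct.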
Training with Force Field Parameters¤
from flax import nnx
import jax.numpy as jnp
import optax
from diffbio.operators.molecular_dynamics import ForceFieldConfig, ForceFieldOperator

# Create learnable force field
config = ForceFieldConfig(potential_type="lennard_jones", box_size=10.0)
force_field = ForceFieldOperator(config, rngs=nnx.Rngs(42))

# Add learnable parameter (e.g., epsilon)
force_field.epsilon = nnx.Param(jnp.array(1.0))

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(force_field, nnx.Param))

def loss_fn(model, positions, target_energy):
    result, _, _ = model.apply({"positions": positions}, {}, None)
    return (result["energy"] - target_energy) ** 2

@nnx.jit
def train_step(model, opt_state, positions, target):
    loss, grads = nnx.value_and_grad(loss_fn)(model, positions, target)
    params = nnx.state(model, nnx.Param)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    nnx.update(model, optax.apply_updates(params, updates))
    return loss, opt_state
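As a self-contained illustration of parameter fitting, the sketch below recovers a Lennard-Jones well depth from a single reference energy using plain gradient descent in JAX. It does not use the DiffBio operator or optax. The helper `lj_energy`, the three-particle configuration, the synthetic reference generated with epsilon = 2, and the step size are all hypothetical choices for this example.

```python
import jax
import jax.numpy as jnp


def lj_energy(epsilon, positions, sigma=1.0):
    """Total 12-6 Lennard-Jones energy (free boundaries, no cutoff)."""
    n = positions.shape[0]
    i, j = jnp.triu_indices(n, k=1)  # unique pairs with i < j
    r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    sr6 = (sigma / r) ** 6
    return jnp.sum(4.0 * epsilon * (sr6 ** 2 - sr6))


# Three particles on a line, spaced 1.5 sigma apart.
positions = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [3.0, 0.0, 0.0]])

# Synthetic reference energy, generated with epsilon = 2 for this example.
target_energy = lj_energy(2.0, positions)


def loss_fn(epsilon):
    return (lj_energy(epsilon, positions) - target_energy) ** 2


grad_fn = jax.jit(jax.grad(loss_fn))

epsilon = jnp.array(1.0)  # initial guess
for _ in range(100):
    epsilon = epsilon - 0.5 * grad_fn(epsilon)  # plain gradient descent
```

Because the energy is linear in epsilon, the loss is a simple quadratic and the fit converges to the value used to generate the reference.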
Boundary Conditions¤
Periodic Boundaries¤
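For simulating bulk systems, set box_size to the periodic box size (for example, box_size=10.0, as in the Quick Start examples).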
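The effect of a periodic box on particle separations can be sketched with the minimum image convention for a cubic box. This is a generic illustration and an assumption about how periodic displacements are computed, not code taken from DiffBio. The function names `periodic_displacement` and `wrap` are hypothetical.

```python
import jax.numpy as jnp


def periodic_displacement(xa, xb, box_size):
    """Shortest displacement from xb to xa in a cubic periodic box."""
    d = xa - xb
    return d - box_size * jnp.round(d / box_size)


def wrap(positions, box_size):
    """Map positions back into the primary box [0, box_size)."""
    return jnp.mod(positions, box_size)


a = jnp.array([0.5, 5.0, 5.0])
b = jnp.array([9.5, 5.0, 5.0])

d_periodic = periodic_displacement(a, b, box_size=10.0)  # across the boundary
d_free = a - b                                           # free boundaries

wrapped = wrap(jnp.array([10.5, -0.5, 3.0]), box_size=10.0)
```

With free boundaries the two particles are 9 units apart, while in the periodic box they are neighbors at distance 1 across the boundary.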
Free Boundaries¤
For isolated systems (e.g., molecules, clusters), set box_size=None to use free boundaries.
Performance Tips¤
- JIT Compilation: All operators are JIT-compatible for fast execution
- Batched Processing: Force field supports batched positions
- Pre-computed Functions: Energy/force functions are created once at initialization
- Use scan for trajectories: Integrator uses jax.lax.scan internally for efficiency
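The JIT and batching tips can be illustrated with generic JAX transformations. The sketch below applies `jax.jit` and `jax.vmap` to a hypothetical stand-alone energy function `lj_energy`. It shows the pattern only and makes no claim about how the DiffBio operators handle batches internally.

```python
import jax
import jax.numpy as jnp


def lj_energy(positions, sigma=1.0, epsilon=1.0):
    """Total 12-6 Lennard-Jones energy (free boundaries, no cutoff)."""
    i, j = jnp.triu_indices(positions.shape[0], k=1)
    r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    sr6 = (sigma / r) ** 6
    return jnp.sum(4.0 * epsilon * (sr6 ** 2 - sr6))


# Compile once, then evaluate a whole batch of configurations per call.
batched_energy = jax.jit(jax.vmap(lj_energy))
batched_forces = jax.jit(jax.vmap(lambda x: -jax.grad(lj_energy)(x)))

# Eight particles on a 2 x 2 x 2 lattice, replicated at four lattice scales.
grid = jnp.arange(2) * 1.2
base = jnp.stack(jnp.meshgrid(grid, grid, grid, indexing="ij"), axis=-1).reshape(-1, 3)
scales = jnp.array([1.0, 1.05, 1.1, 1.2])
batch = base[None, :, :] * scales[:, None, None]  # (batch, n_particles, dim)

energies = batched_energy(batch)  # (batch,)
forces = batched_forces(batch)    # (batch, n_particles, dim)
```

The output shapes follow the same convention as the ForceFieldOperator output table: one energy per batch entry and forces with the same shape as the positions.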
Use Cases¤
| Application | Description |
|---|---|
| Protein folding | Energy minimization and dynamics |
| Materials science | Crystal structure optimization |
| Drug design | Binding energy calculations |
| Coarse-grained models | Efficient simulation of large systems |
| Force field fitting | Parameter optimization from ab initio data |
References¤
- Schoenholz, S. S. & Cubuk, E. D. (2020). "JAX, M.D.: A Framework for Differentiable Physics." NeurIPS 2020.
- Doerr, S. et al. (2021). "TorchMD: A Deep Learning Framework for Molecular Simulations." JCTC 17(4), 2355-2363.
- Frenkel, D. & Smit, B. (2002). "Understanding Molecular Simulation." Academic Press.
Next Steps¤
- See Protein Structure Operators for secondary structure prediction
- Explore Statistical Operators for analysis methods
- Check Preprocessing Operators for data preparation