-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlj.py
63 lines (50 loc) · 1.95 KB
/
lj.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import jax
import jax.numpy as jnp
from glp.neighborlist import quadratic_neighbor_list
from glp.system import atoms_to_system
from glp.graph import system_to_graph
from glp.utils import cast, distance
from .potential import Potential
def lennard_jones(sigma=2.0, epsilon=1.5, cutoff=10.0, onset=6.0):
    """Build a smoothly truncated Lennard-Jones pair potential.

    The pair energy is ``4 * epsilon * ((sigma / r)**12 - (sigma / r)**6)``,
    multiplied by a mollifier that is exactly 1 for ``r < onset``, decays as a
    polynomial in ``r**2`` to 0 at ``r = cutoff``, and is 0 beyond it. Edge
    contributions are summed onto each edge's ``center`` node.

    Args:
        sigma: Length scale of the Lennard-Jones term.
        epsilon: Energy scale; the unmollified pair minimum is ``-epsilon``.
        cutoff: Distance at which the mollified energy reaches zero; also
            passed through to ``Potential``.
        onset: Distance at which the mollifier starts switching the energy
            off. Should be strictly smaller than ``cutoff``. With
            ``onset == cutoff`` the mollifier expression divides by zero,
            which can produce nan gradients. With ``onset > cutoff``, pairs
            between ``cutoff`` and ``onset`` keep full weight.

    Returns:
        ``Potential(lennard_jones_fn, cutoff)``, where ``lennard_jones_fn``
        maps a graph to an array of length ``graph.nodes.shape[0]``.
    """
    # we assume double counting (each pair appears as both i->j and j->i),
    # so 4 * epsilon / 2 is the prefactor per directed edge
    factor = cast(2 * epsilon)
    sigma = cast(sigma)
    cutoff2 = cast(cutoff**2)
    onset2 = cast(onset**2)
    zero = cast(0.0)
    one = cast(1.0)
    # finite stand-in length for masked-out edges: at r = cutoff the mollifier
    # selects exactly zero and sigma / r is finite (for cutoff > 0)
    safe_distance = cast(cutoff)

    def pairwise_energy_fn(dr):
        """Unmollified Lennard-Jones energy of one directed edge of length dr."""
        inverse_r = sigma / dr
        inverse_r6 = inverse_r ** cast(6.0)
        inverse_r12 = inverse_r6 * inverse_r6
        return factor * (inverse_r12 - inverse_r6)

    def cutoff_fn(dr):
        """Switching weight: 1 below onset, smooth decay to 0 at cutoff."""
        # inspired by jax-md, which in turn uses HOOMD-blue
        distance2 = dr ** cast(2.0)
        # between onset and cutoff: polynomial equal to 1 at onset and 0 at
        # cutoff; at or beyond cutoff: zero
        after_onset = jnp.where(
            distance2 < cutoff2,
            (cutoff2 - distance2) ** cast(2.0)
            * (cutoff2 + cast(2.0) * distance2 - cast(3.0) * onset2)
            / (cutoff2 - onset2) ** cast(3.0),
            zero,
        )
        # do nothing before onset, then mollify
        return jnp.where(distance2 < onset2, one, after_onset)

    def pair_lj(dr):
        """Mollified Lennard-Jones energy of one directed edge of length dr."""
        return cutoff_fn(dr) * pairwise_energy_fn(dr)

    def lennard_jones_fn(graph):
        """Return energies summed per node over edges grouped by graph.centers."""
        distances = distance(graph.edges)
        # Edges with a false mask (presumably neighbour-list padding -- confirm)
        # may have degenerate lengths such as 0, where sigma / dr is inf and
        # multiplying by the mask would give nan. Evaluate those edges at a
        # safe finite distance, then select zero instead of multiplying, so
        # neither the forward value nor pair_lj's gradient becomes nan.
        # NOTE(review): the gradient of `distance` itself at a zero-length
        # edge is not guarded here; that depends on glp.utils.distance.
        safe_distances = jnp.where(graph.mask, distances, safe_distance)
        contributions = jax.vmap(pair_lj)(safe_distances)
        out = jnp.where(graph.mask, contributions, zero)
        # NOTE(review): indices_are_sorted=True assumes graph.centers is
        # sorted by the neighbour list -- confirm.
        return jax.ops.segment_sum(
            out, graph.centers, graph.nodes.shape[0], indices_are_sorted=True
        )

    return Potential(lennard_jones_fn, cutoff)