Skip to content

Commit

Permalink
add calculation of population ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
aoymt committed Feb 22, 2025
1 parent 3d10b1f commit 35a42dd
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions src/odatse/algorithm/pamc.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def _initialize(self) -> None:
self.fx_from_reset = np.zeros((self.resampling_interval, self.nwalkers))
self.naccepted_from_reset = np.zeros((self.resampling_interval, 2), dtype=int)
self.acceptance_ratio = np.zeros(numT)
self.pr_list = np.zeros(numT)

self._show_parameters()

Expand Down Expand Up @@ -300,6 +301,11 @@ def _run(self) -> None:
self.logweights[iwalker],
*self.x[iwalker,:])

# calculate population ratio
pr = self._calc_population_ratio()
self.pr_list[Tindex] = pr


if Tindex == numT - 1:
break

Expand Down Expand Up @@ -616,6 +622,30 @@ def _resample_fixed(self, weights: np.ndarray) -> None:
self.x = self.node_coordinates[self.inode, :]
self.fx = fxs[new_index]

def _calc_population_ratio(self):
if self.mpisize > 1:
from mpi4py import MPI
max_log_weight = self.mpicomm.allreduce(np.max(self.logweights), op=MPI.MAX)

buf = [
np.sum(np.exp(self.logweights - max_log_weight)),
np.sum(np.exp(self.logweights - max_log_weight)**2),
]
buf_sum = self.mpicomm.allreduce(buf, op=MPI.SUM)

sum_weight = buf_sum[0]
sum_weight_sq = buf_sum[1]

else:
max_log_weight = np.max(self.logweights)

sum_weight = np.sum(np.exp(self.logweights - max_log_weight))
sum_weight_sq = np.sum(np.exp(self.logweights - max_log_weight)**2)

pr = sum_weight ** 2 / sum_weight_sq

return pr

def _prepare(self) -> None:
"""
Prepare the algorithm for execution.
Expand Down Expand Up @@ -706,6 +736,17 @@ def _post(self) -> None:
f.write(f" {self.logZs[i]}")
f.write(f" {self.acceptance_ratio[i]}")
f.write("\n")

with open("pr.txt", "w") as f:
f.write("# $1: Tindex\n")
f.write("# $2: 1/T\n")
f.write("# $3: population ratio\n")
for i in range(len(self.betas)):
f.write(f"{i}")
f.write(f" {self.betas[i]}")
f.write(f" {self.pr_list[i]}")
f.write("\n")

return {
"x": best_x[best_rank],
"fx": best_fx[best_rank],
Expand Down Expand Up @@ -761,6 +802,7 @@ def _save_state(self, filename) -> None:
"fx_from_reset": self.fx_from_reset,
"naccepted_from_reset": self.naccepted_from_reset,
"acceptance_ratio": self.acceptance_ratio,
"pr_list": self.pr_list,
}
self._save_data(data, filename)

Expand Down Expand Up @@ -850,6 +892,7 @@ def _load_state(self, filename, mode="resume", restore_rng=True):
self.nreplicas = np.full(numT, nreplicas)
self.populations = np.zeros((numT, self.nwalkers), dtype=int)
self.acceptance_ratio = np.zeros(numT)
self.pr_list = np.zeros(numT)

self.logZ = data["logZ"]
self.logZs[0:len(data["logZs"])] = data["logZs"]
Expand All @@ -866,5 +909,4 @@ def _load_state(self, filename, mode="resume", restore_rng=True):
self.walker_ancestors = data["walker_ancestors"]
self.naccepted_from_reset = data["naccepted_from_reset"]
self.acceptance_ratio[0:len(data["acceptance_ratio"])] = data["acceptance_ratio"]


self.pr_list[0:len(data["pr_list"])] = data["pr_list"]

0 comments on commit 35a42dd

Please # to comment.