# Source code for alabi.cache_utils
"""
:py:mod:`cache_utils.py`
-------------------------------------
"""
import numpy as np
import pickle
import os
from . import parallel_utils
# Names exported by ``from alabi.cache_utils import *``.
__all__ = [
    "load_pickle",
    "load_model_cache",
    "write_report_gp",
    "write_report_emcee",
    "write_report_dynesty",
]
# [docs]
def load_pickle(savedir, fname="surrogate_model.pkl"):
    """
    Unpickle and return the object stored in ``fname`` inside ``savedir``.

    :param savedir: Directory containing the pickle file
    :param fname: Name of the pickle file within ``savedir``
    :returns: The object read from the pickle file
    """
    pickle_path = os.path.join(savedir, fname)
    # NOTE: unpickling can execute arbitrary code; only point this at files
    # that come from a trusted source.
    with open(pickle_path, "rb") as handle:
        return pickle.load(handle)
# [docs]
def load_model_cache(savedir):
    """
    Load the cached surrogate model, reading the file on rank 0 only.

    When ``parallel_utils.is_mpi_active()`` is true and ``mpi4py`` can be
    imported, rank 0 unpickles the model and broadcasts it to every other
    rank, so the cache file is opened by a single process. Otherwise the
    model is simply loaded in the current process.

    :param savedir: Directory containing the model cache
    :returns: Loaded surrogate model
    :raises Exception: Whatever ``load_pickle`` raises on rank 0 is reported
        and re-raised unchanged
    """
    # Resolve the communicator once so that the "who loads" and "who
    # broadcasts" decisions below can never disagree with each other.
    comm = None
    rank = 0
    if parallel_utils.is_mpi_active():
        try:
            from mpi4py import MPI
        except ImportError:
            # mpi4py is unavailable: behave as a single process
            pass
        else:
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()

    # Only rank 0 loads the model
    if rank == 0:
        try:
            sm = load_pickle(savedir)
        except Exception as e:
            print(f"Rank {rank}: Failed to load model cache: {e}")
            # NOTE(review): under MPI the other ranks are still waiting in
            # bcast when this propagates -- confirm callers abort the job.
            raise  # Re-raise the exception to properly handle the error
    else:
        # Placeholder only; replaced by the broadcast from rank 0 below
        sm = None

    # Broadcast the model to all ranks if using MPI
    if comm is not None:
        sm = comm.bcast(sm, root=0)
        if rank != 0:
            print(f"Rank {rank}: Received model from rank 0")

    return sm
# [docs]
def write_report_gp(self, file):
    """
    Write a human-readable GP summary to ``file + ".txt"``.

    The report file is opened in write mode, so an existing file at that path
    is overwritten. Configuration attributes, ``train_runtime`` and
    ``training_results`` are reported only when present on ``self``.

    :param self: Object exposing ``gp`` with ``get_parameter_names()`` and
        ``get_parameter_vector()``
    :param file: Output path without the ".txt" extension
    """
    # get hyperparameter names and values
    hp_name = self.gp.get_parameter_names()
    hp_vect = self.gp.get_parameter_vector()

    # report label -> attribute name looked up on ``self``
    report_vars = {
        "Kernel": "kernel_name",
        "Function bounds": "bounds",
        "fit mean": "fit_mean",
        "fit amplitude": "fit_amp",
        "fit white_noise": "fit_white_noise",
        "GP white noise": "white_noise",
        "Hyperparameter bounds": "hp_bounds",
        "Active learning algorithm": "algorithm",
        "Number of total training samples": "ntrain",
        "Number of initial training samples": "ninit_train",
        "Number of active training samples": "nactive",
        "Number of test samples": "ntest",
    }

    parts = [
        "==================================================================\n",
        "GP summary \n",
        "==================================================================\n\n",
        "Configuration: \n",
        "-------------- \n",
    ]
    for label, attr in report_vars.items():
        if hasattr(self, attr):
            parts.append(f"{label}: {getattr(self, attr)} \n")
    parts.append("\n")

    parts.append("Results: \n")
    parts.append("-------- \n")
    parts.append("GP final hyperparameters: \n")
    for name, value in zip(hp_name, hp_vect):
        parts.append(f" [{name}] \t{value} \n")
    parts.append("\n")

    if hasattr(self, 'train_runtime'):
        parts.append(f"Active learning train runtime (s): {np.round(self.train_runtime)} \n\n")

    if hasattr(self, 'training_results'):
        parts.append(f"Final test error (MSE): {self.training_results['test_mse'][-1]} \n\n")

    # Context manager guarantees the handle is closed even if write() fails
    with open(file + ".txt", "w") as summary:
        summary.write("".join(parts))
# [docs]
def write_report_emcee(self, file):
    """
    Append an emcee run summary to ``file + ".txt"``.

    The report file is opened in append mode, so earlier content at that path
    is kept. Per-parameter means and standard deviations are computed over
    axis 0 of ``self.emcee_samples``.

    :param self: Object exposing ``emcee_samples``, ``nwalkers``, ``nsteps``,
        ``acc_frac``, ``autcorr_time``, ``iburn``, ``ithin``,
        ``emcee_runtime``, ``ndim`` and ``labels``
    :param file: Output path without the ".txt" extension
    """
    # compute summary statistics
    samples = self.emcee_samples
    means = np.mean(samples, axis=0)
    stds = np.std(samples, axis=0)

    parts = [
        "==================================================================\n",
        "emcee summary \n",
        "==================================================================\n\n",
        "Configuration: \n",
        "-------------- \n",
        f"Number of walkers: {self.nwalkers} \n",
        f"Number of steps per walker: {self.nsteps} \n\n",
        "Results: \n",
        "-------- \n",
        f"Mean acceptance fraction: {self.acc_frac:.3f} \n",
        # NOTE(review): attribute is spelled ``autcorr_time``; it must match
        # the name used where the attribute is set -- confirm before renaming.
        f"Mean autocorrelation time: {self.autcorr_time:.3f} steps \n",
        f"Burn: {self.iburn} \n",
        f"Thin: {self.ithin} \n",
        f"Total burned, thinned, flattened samples: {samples.shape[0]} \n\n",
        f"emcee runtime (s): {np.round(self.emcee_runtime)} \n\n",
        "Summary statistics: \n",
    ]
    parts.extend(
        f"{self.labels[ii]} = {means[ii]} +/- {stds[ii]} \n"
        for ii in range(self.ndim)
    )
    parts.append("\n")

    # Context manager guarantees the handle is closed even if write() fails
    with open(file + ".txt", "a") as summary:
        summary.write("".join(parts))
# [docs]
def write_report_dynesty(self, file):
    """
    Append a dynesty run summary to ``file + ".txt"``.

    The report file is opened in append mode, so earlier content at that path
    is kept. Per-parameter means and standard deviations are computed over
    axis 0 of ``self.dynesty_samples``.

    :param self: Object exposing ``dynesty_samples``, ``dynesty_runtime``,
        ``ndim`` and ``labels``
    :param file: Output path without the ".txt" extension
    """
    # compute summary statistics
    samples = self.dynesty_samples
    means = np.mean(samples, axis=0)
    stds = np.std(samples, axis=0)

    parts = [
        "==================================================================\n",
        "dynesty summary \n",
        "==================================================================\n\n",
        # The configuration section is emitted with a header only
        "Configuration: \n",
        "-------------- \n",
        "Results: \n",
        "-------- \n",
        f"Total weighted samples: {samples.shape[0]} \n\n",
        f"Dynesty runtime (s): {np.round(self.dynesty_runtime)} \n\n",
        "Summary statistics: \n",
    ]
    parts.extend(
        f"{self.labels[ii]} = {means[ii]} +/- {stds[ii]} \n"
        for ii in range(self.ndim)
    )
    parts.append("\n")

    # Context manager guarantees the handle is closed even if write() fails
    with open(file + ".txt", "a") as summary:
        summary.write("".join(parts))