diff --git a/sgptools/utils/gpflow.py b/sgptools/utils/gpflow.py index 98f695e..3fe66a9 100644 --- a/sgptools/utils/gpflow.py +++ b/sgptools/utils/gpflow.py @@ -21,6 +21,8 @@ import numpy as np import matplotlib.pyplot as plt +from .misc import get_inducing_pts + def plot_loss(losses, save_file=None): """Helper function to plot the training loss @@ -52,7 +54,8 @@ def get_model_params(X_train, y_train, kernel=None, return_gp=False, **kwargs): - """Train a GP on the given training set + """Train a GP on the given training set. + Trains a sparse GP if the training set is larger than 1500 samples. Args: X_train (ndarray): (n, d); Training set inputs @@ -79,12 +82,23 @@ def get_model_params(X_train, y_train, kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales, variance=variance) - gpr = gpflow.models.GPR(data=(X_train, y_train), - kernel=kernel, - noise_variance=noise_variance) + if len(X_train) <= 1500: + gpr = gpflow.models.GPR(data=(X_train, y_train), + kernel=kernel, + noise_variance=noise_variance) + trainable_variables=gpr.trainable_variables + else: + inducing_pts = get_inducing_pts(X_train, 500) + gpr = gpflow.models.SGPR(data=(X_train, y_train), + kernel=kernel, + inducing_variable=inducing_pts, + noise_variance=noise_variance) + trainable_variables=gpr.trainable_variables[1:] if max_steps > 0: - loss = optimize_model(gpr, max_steps=max_steps, lr=lr, **kwargs) + loss = optimize_model(gpr, max_steps=max_steps, lr=lr, + trainable_variables=trainable_variables, + **kwargs) else: loss = 0