Skip to content

Commit

Permalink
Update get_model_params to use sgp for large datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik committed Dec 26, 2024
1 parent 51100dc commit 1f40218
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions sgptools/utils/gpflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import numpy as np
import matplotlib.pyplot as plt

from .misc import get_inducing_pts


def plot_loss(losses, save_file=None):
"""Helper function to plot the training loss
Expand Down Expand Up @@ -52,7 +54,8 @@ def get_model_params(X_train, y_train,
kernel=None,
return_gp=False,
**kwargs):
"""Train a GP on the given training set
"""Train a GP on the given training set.
Trains a sparse GP if the training set is larger than 1500 samples.
Args:
X_train (ndarray): (n, d); Training set inputs
Expand All @@ -79,12 +82,23 @@ def get_model_params(X_train, y_train,
kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
variance=variance)

gpr = gpflow.models.GPR(data=(X_train, y_train),
kernel=kernel,
noise_variance=noise_variance)
if len(X_train) <= 1500:
gpr = gpflow.models.GPR(data=(X_train, y_train),
kernel=kernel,
noise_variance=noise_variance)
trainable_variables=gpr.trainable_variables
else:
inducing_pts = get_inducing_pts(X_train, 500)
gpr = gpflow.models.SGPR(data=(X_train, y_train),
kernel=kernel,
inducing_variable=inducing_pts,
noise_variance=noise_variance)
trainable_variables=gpr.trainable_variables[1:]

if max_steps > 0:
loss = optimize_model(gpr, max_steps=max_steps, lr=lr, **kwargs)
loss = optimize_model(gpr, max_steps=max_steps, lr=lr,
trainable_variables=trainable_variables,
**kwargs)
else:
loss = 0

Expand Down

0 comments on commit 1f40218

Please sign in to comment.