From 30f827586319e3c190e0d7ab655887331d45a0db Mon Sep 17 00:00:00 2001 From: GrantPerkins Date: Tue, 13 Sep 2022 14:30:56 +0200 Subject: [PATCH] Initial commit --- main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index d9d3eab..084b0dc 100644 --- a/main.py +++ b/main.py @@ -38,7 +38,7 @@ validation_split = 0.20 test_split = 0.20 -dataloader = TimeSeriesDataLoader(X, y, validation_split=validation_split, test_split=test_split, period=1000, batch_size=10) +dataloader = TimeSeriesDataLoader(X, y, validation_split=validation_split, test_split=test_split, period=1000, batch_size=32) model = SimpleLSTM(X.shape[1], 100, 3, batch_first=True, dropout=0.5) if cuda_available: @@ -61,6 +61,8 @@ forecast, _ = model.forecast(X.unsqueeze(0)) forecast = forecast.flatten().cpu().detach().numpy() +print(f"Forecast, {forecast}") +print(f"Expected, {y}") fig, axs = plt.subplots(ncols=2, figsize=(10, 5))