Store predictions
deinal committed Aug 14, 2024
1 parent 9bd5290 commit 0a09700
Showing 7 changed files with 140 additions and 23 deletions.
README.md (3 changes: 2 additions & 1 deletion)

@@ -53,6 +53,7 @@ Mediterranean analysis
 ```
 python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/train -n 6 -p ana_data -s 2022-01-01 -e 2024-04-30
 python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/val -n 6 -p ana_data -s 2024-05-01 -e 2024-06-30
+python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/test -n 17 -p ana_data -s 2024-07-22 -e 2024-08-12 --forecast
 ```
 
 ERA5
@@ -63,7 +64,7 @@ python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/sa
 
 Forecast data
 ```
-python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/test -n 16 -p ana_data -s 2024-06-30 -e 2024-08-10 --forecast
+python prepare_states.py -d data/mediterranean/raw/forecast -o data/mediterranean/samples/test -p for_data -s 2024-07-24 -e 2024-08-01 --forecast
 python prepare_states.py -d data/mediterranean/raw/ens -o data/mediterranean/samples/test -p ens_forcing -s 2024-07-01 -e 2024-08-11 --forecast
 python prepare_states.py -d data/mediterranean/raw/aifs -o data/mediterranean/samples/test -p aifs_forcing -s 2024-06-01 -e 2024-08-11 --forecast
 ```
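A quick sanity check on the size of the new test split, assuming one analysis `.npy` file per day in the date range (an assumption; the README does not state the file cadence). `prepare_states.py` slides a window of `n_states` over the sorted files:

```python
from datetime import date

# Expected number of test samples from the new README command, assuming
# one analysis .npy file per day (assumption, not stated in the README).
n_states = 17  # -n 17
n_days = (date(2024, 8, 12) - date(2024, 7, 22)).days + 1  # 22 daily files
n_samples = n_days - n_states + 1  # window loop: range(len(files) - n_states + 1)
print(n_samples)  # 6
```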
download_data.py (4 changes: 3 additions & 1 deletion)

@@ -261,6 +261,8 @@ def download_forecast(
 
     filename = f"{path_prefix}/{start_date.strftime('%Y%m%d')}.npy"
 
+    initial_date = start_date - timedelta(days=2)
+
     all_data = []
     for dataset_id, variables in datasets.items():
         # Load ocean physics dataset for all dates at once
@@ -270,7 +272,7 @@ def download_forecast(
             dataset_part="default",
             service="arco-geo-series",
             variables=variables,
-            start_datetime=start_date.strftime("%Y-%m-%dT00:00:00"),
+            start_datetime=initial_date.strftime("%Y-%m-%dT00:00:00"),
             end_datetime=end_date.strftime("%Y-%m-%dT00:00:00"),
             minimum_depth=constants.DEPTHS[0],
             maximum_depth=constants.DEPTHS[-1],
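The request window now opens two days before the nominal forecast start, which lines up with the two initial states the rest of this commit assumes (for example, the `files[i + 2]` naming and `[2:]` slicing in prepare_states.py). A minimal sketch of the offset:

```python
from datetime import datetime, timedelta

# Sketch of the new windowing: fetching from two days before start_date
# keeps the two conditioning states in the same downloaded file as the
# forecast itself. The example date is illustrative, not from the diff.
start_date = datetime(2024, 7, 24)
initial_date = start_date - timedelta(days=2)
print(initial_date.strftime("%Y-%m-%dT00:00:00"))  # 2024-07-22T00:00:00
```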
neural_lam/constants.py (6 changes: 4 additions & 2 deletions)

@@ -6,13 +6,15 @@
 
 # Log prediction error for these lead times
 VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 4])
-TEST_STEP_LOG_ERRORS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
+TEST_STEP_LOG_ERRORS = np.array(
+    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+)
 
 # Sample lengths
 SAMPLE_LEN = {
     "train": 6,
     "val": 6,
-    "test": 16,
+    "test": 17,
 }
 
 # Log these metrics to wandb as scalar values for
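The two constants move together: with two states reserved as initial conditions, a 17-state test sample leaves a 15-step rollout, matching the extended lead times. A worked check (the two-initial-state convention is inferred from the rest of the commit, not stated in this file):

```python
# Relation between the two updated constants, as a sketch.
sample_len_test = 17   # SAMPLE_LEN["test"]
n_init_states = 2      # states used as initial conditions (inferred)
pred_steps = sample_len_test - n_init_states
assert pred_steps == 15  # last entry of TEST_STEP_LOG_ERRORS
```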
neural_lam/models/ar_model.py (44 changes: 44 additions & 0 deletions)

@@ -27,10 +27,12 @@ def __init__(self, args):
         self.save_hyperparameters()
         self.optimizer = args.optimizer
         self.lr = args.lr
+        self.batch_size = args.batch_size
         self.epochs = args.epochs
         self.scheduler = args.scheduler
         self.initial_lr = args.initial_lr
         self.warmup_epochs = args.warmup_epochs
+        self.store_pred = args.store_pred
 
         # Load static features for grid/data
         static_data_dict = utils.load_static_data(args.dataset)
@@ -99,6 +101,9 @@ def __init__(self, args):
         # For storing spatial loss maps during evaluation
         self.spatial_loss_maps = []
 
+        # For storing predictions under sample names
+        self.sample_names = []
+
     def configure_optimizers(self):
         if self.optimizer == "adamw":
             opt = torch.optim.AdamW(
@@ -166,6 +171,12 @@ def expand_to_batch(x, batch_size):
         """
         return x.unsqueeze(0).expand(batch_size, -1, -1)
 
+    def set_sample_names(self, dataset):
+        """
+        Set sample names for evaluation
+        """
+        self.sample_names = dataset.sample_names
+
     def predict_step(self, prev_state, prev_prev_state, forcing):
         """
         Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
@@ -393,6 +404,10 @@ def test_step(self, batch, batch_idx):
         self.spatial_loss_maps.append(log_spatial_losses)
         # (B, N_log, num_grid_nodes)
 
+        # Store predictions
+        if self.store_pred:
+            self.store_predictions(batch_idx, prediction)
+
         # Plot example predictions (on rank 0 only)
         if (
             self.trainer.is_global_zero
@@ -407,6 +422,35 @@ def test_step(self, batch, batch_idx):
                 batch, n_additional_examples, prediction=prediction
             )
 
+    def store_predictions(self, batch_idx, prediction):
+        """
+        Store predictions for a batch
+        batch_idx: index of the batch in the dataloader
+        prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction.
+        """
+
+        sample_names = [
+            self.sample_names[idx]
+            for idx in range(
+                batch_idx * self.batch_size,
+                (batch_idx + 1) * self.batch_size,
+            )
+        ]
+
+        # Rescale to original data scale
+        prediction_rescaled = prediction * self.data_std + self.data_mean
+
+        pred_dir = os.path.join(wandb.run.dir, "predictions")
+        os.makedirs(pred_dir, exist_ok=True)
+
+        # Save pred as .npy files
+        for i, sample_name in enumerate(sample_names):
+            np.save(
+                os.path.join(pred_dir, f"{sample_name}.npy"),
+                prediction_rescaled[i].cpu().numpy(),
+            )
+
     def plot_examples(self, batch, n_examples, prediction=None):
         """
         Plot the first n_examples forecasts from batch
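Stored files land under the wandb run directory, one rescaled array per test sample. A hypothetical read-back (the run path is a placeholder; the per-sample shape follows the docstring above):

```python
import glob
import os

import numpy as np

# Hypothetical read-back of stored predictions; run_dir is a placeholder
# for the files/ directory of the wandb evaluation run.
run_dir = "wandb/run-20240814_120000-abc123/files"
pred_dir = os.path.join(run_dir, "predictions")

for path in sorted(glob.glob(os.path.join(pred_dir, "*.npy"))):
    pred = np.load(path)  # (pred_steps, num_grid_nodes, d_f), original scale
    print(os.path.basename(path), pred.shape)
```

Note that `store_predictions` recovers names from `batch_idx * self.batch_size`, so it relies on the evaluation loader preserving dataset order (it is built with `shuffle=False` in train_model.py) and on full batches; a final partial batch would over-run the name list.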
neural_lam/weather_dataset.py (6 changes: 5 additions & 1 deletion)

@@ -48,7 +48,11 @@ def __init__(
             else (
                 "ana_data_*.npy"
                 if data_subset == "analysis"
-                else "*_data_*.npy"
+                else (
+                    "for_data_*.npy"
+                    if data_subset == "forecast" and split == "test"
+                    else "*_data_*.npy"
+                )
             )
         )
         sample_paths = glob.glob(
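The growing conditional-expression chain picks a glob pattern per data subset. Written out as a plain function for readability; a sketch only, since the branch above this hunk is outside the shown context, so its condition and pattern are assumptions here:

```python
def sample_pattern(data_subset, split):
    # Flattened version of the nested conditional in WeatherDataset.__init__.
    # The first branch is not shown in the hunk; "rea_data_*.npy" for
    # "reanalysis" is an assumption based on the ana_data/for_data naming.
    if data_subset == "reanalysis":
        return "rea_data_*.npy"
    if data_subset == "analysis":
        return "ana_data_*.npy"
    if data_subset == "forecast" and split == "test":
        return "for_data_*.npy"
    return "*_data_*.npy"
```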
prepare_states.py (91 changes: 74 additions & 17 deletions)

@@ -46,8 +46,8 @@ def prepare_states(
 
     # Process each file, concatenate with the next t-1 files
     for i in range(len(files) - n_states + 1):
-        # Name as today's date
-        out_filename = f"{prefix}_{os.path.basename(files[i+1])}"
+        # Name as first forecasted date
+        out_filename = f"{prefix}_{os.path.basename(files[i + 2])}"
         out_file = os.path.join(out_directory, out_filename)
 
         if os.path.isfile(out_file):
@@ -132,27 +132,22 @@ def prepare_states_with_boundary(
 
     # Process each file, concatenate with the next t-1 files
    for i in range(len(files) - n_states + 1):
-        today = os.path.basename(files[i + 1])
-        out_filename = f"{prefix}_{today}"
+        forecast_date = os.path.basename(files[i + 2])
+        out_filename = f"{prefix}_{forecast_date}"
         out_file = os.path.join(out_directory, out_filename)
 
         if os.path.isfile(out_file):
             continue
 
         # Stack analysis states
         state_sequence = [np.load(files[i + j]) for j in range(n_states)]
         full_state = np.stack(state_sequence, axis=0)
         print("full state", full_state.shape)  # (n_states, N_grid, d_features)
 
-        forecast_file = files[i + 1].replace("analysis", "forecast")
-        forecast_data = np.load(forecast_file)
-        forecast_len = forecast_data.shape[0]
+        forecast_file = files[i + 2].replace("analysis", "forecast")
+        forecast_data = np.load(forecast_file)[2:]
-        print(
-            "forecast before", forecast_data.shape
-        )  # (forecast_len, N_grid, d_features)
 
-        assert n_states >= forecast_len, "n_states less than forecast length"
-        extra_states = n_states - 1 - forecast_data.shape[0]
+        extra_states = 5
         last_forecast_state = forecast_data[-1]
         repeated_forecast_states = np.repeat(
             last_forecast_state[np.newaxis, ...], extra_states, axis=0
@@ -162,21 +157,65 @@
         )
         print(
             "forecast after", forecast_data.shape
-        )  # (n_states - 1, N_grid, d_features)
+        )  # (n_states - 2, N_grid, d_features)
 
         # Concatenate preceding day analysis state with forecast data
         forecast_data = np.concatenate(
-            (state_sequence[:1], forecast_data), axis=0
+            (state_sequence[:2], forecast_data), axis=0
         )  # (n_states, N_grid, d_features)
 
         full_state = (
             full_state * (1 - border_mask) + forecast_data * border_mask
         )
 
-        np.save(out_file, full_state)
+        np.save(out_file, full_state.astype(np.float32))
         print(f"Saved states to: {out_file}")
 
 
+def prepare_forecast(in_directory, out_directory, prefix, start_date, end_date):
+    """
+    Prepare forecast data by repeating the last state.
+    """
+    forecast_dir = in_directory
+
+    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
+    end_dt = datetime.strptime(end_date, "%Y-%m-%d")
+
+    os.makedirs(out_directory, exist_ok=True)
+
+    # Get files sorted by date
+    forecast_files = sorted(
+        glob(os.path.join(forecast_dir, "*.npy")),
+        key=lambda x: datetime.strptime(os.path.basename(x)[:8], "%Y%m%d"),
+    )
+    forecast_files = [
+        f
+        for f in forecast_files
+        if start_dt
+        <= datetime.strptime(os.path.basename(f)[:8], "%Y%m%d")
+        <= end_dt
+    ]
+
+    for forecast_file in forecast_files:
+        # Load the current forecast data
+        forecast_data = np.load(forecast_file)
+        print(forecast_data.shape)
+
+        last_forecast_state = forecast_data[-1]
+        repeated_forecast_states = np.repeat(
+            last_forecast_state[np.newaxis, ...], repeats=5, axis=0
+        )
+        forecast_data = np.concatenate(
+            [forecast_data, repeated_forecast_states], axis=0
+        )
+
+        # Save concatenated data
+        out_filename = f"{prefix}_{os.path.basename(forecast_file)}"
+        out_file = os.path.join(out_directory, out_filename)
+        np.save(out_file, forecast_data)
+        print(f"Saved forecast to: {out_file}")
+
+
 def prepare_forcing(in_directory, out_directory, prefix, start_date, end_date):
     """
     Prepare atmospheric forcing data from forecasts.
@@ -205,6 +244,14 @@ def prepare_forcing(in_directory, out_directory, prefix, start_date, end_date):
         forecast_date = datetime.strptime(
             os.path.basename(forecast_file)[:8], "%Y%m%d"
         )
+
+        # Get files for the pre-preceding day
+        prepreceding_day_file = os.path.join(
+            forecast_dir,
+            (forecast_date - timedelta(days=2)).strftime("%Y%m%d") + ".npy",
+        )
+        prepreceding_day_data = np.load(prepreceding_day_file)[0:1]
+
         # Get files for the preceding day
         preceding_day_file = os.path.join(
             forecast_dir,
@@ -217,12 +264,14 @@ def prepare_forcing(in_directory, out_directory, prefix, start_date, end_date):
 
         print(preceding_day_data.shape, current_forecast_data.shape)
 
+        prepreceding_day_data = prepreceding_day_data[:, :, :4]
         preceding_day_data = preceding_day_data[:, :, :4]
         current_forecast_data = current_forecast_data[:, :, :4]
 
         # Concatenate all data along the time axis
         concatenated_forcing = np.concatenate(
-            [preceding_day_data, current_forecast_data], axis=0
+            [prepreceding_day_data, preceding_day_data, current_forecast_data],
+            axis=0,
         )
 
         # Save concatenated data
@@ -303,7 +352,7 @@ def main():
                 args.start_date,
                 args.end_date,
             )
-        else:
+        elif args.data_dir.endswith("analysis"):
             prepare_states_with_boundary(
                 args.data_dir,
                 args.static_dir,
@@ -313,6 +362,14 @@ def main():
                 args.start_date,
                 args.end_date,
             )
+        else:
+            prepare_forecast(
+                args.data_dir,
+                args.out_dir,
+                args.prefix,
+                args.start_date,
+                args.end_date,
+            )
     else:
         prepare_states(
             args.data_dir,
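Both `prepare_states_with_boundary` (`extra_states = 5`) and the new `prepare_forecast` (`repeats=5`) pad with five persisted copies of the last forecast state, so that test samples reach the fixed 17-state length. The padding idea in isolation, with toy shapes:

```python
import numpy as np

# Persistence padding as used above: repeat the last state five times and
# append it, so a (k, N_grid, d_features) array becomes (k + 5, ...).
forecast_data = np.zeros((10, 8, 3), dtype=np.float32)  # toy shapes
last_forecast_state = forecast_data[-1]
repeated = np.repeat(last_forecast_state[np.newaxis, ...], repeats=5, axis=0)
padded = np.concatenate([forecast_data, repeated], axis=0)
print(padded.shape)  # (15, 8, 3)
```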
train_model.py (9 changes: 8 additions & 1 deletion)

@@ -185,7 +185,7 @@ def main():
     parser.add_argument(
         "--data_subset",
         type=str,
-        choices=["analysis", "reanalysis"],
+        choices=["analysis", "reanalysis", "forecast"],
         default=None,
         help="Type of data to use: 'analysis' or 'reanalysis' (default: None)",
     )
@@ -259,6 +259,12 @@ def main():
         help="Number of example predictions to plot during evaluation "
         "(default: 1)",
     )
+    parser.add_argument(
+        "--store_pred",
+        type=int,
+        default=0,
+        help="Whether or not to store predictions (default: 0 (no))",
+    )
     args = parser.parse_args()
 
     # Asserts for arguments
@@ -391,6 +397,7 @@ def main():
         shuffle=False,
         num_workers=args.n_workers,
     )
+    model.set_sample_names(eval_loader.dataset)
 
     print(f"Running evaluation on {args.eval}")
     trainer.test(model=model, dataloaders=eval_loader)
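Taken together, an evaluation run along the lines of `python train_model.py --eval test --data_subset forecast --store_pred 1` (other required flags omitted; the exact invocation is not part of this diff) would select the `for_data_*` test samples, register their names on the model via `set_sample_names`, and write one rescaled `.npy` prediction per sample into the run's `predictions/` directory.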
