Commit

save as csv
kashif committed Sep 20, 2024
1 parent bc9ad40 commit 277056e
Showing 1 changed file with 29 additions and 4 deletions.
examples/anomaly_detection_pytorch.py: 29 additions & 4 deletions
@@ -17,6 +17,8 @@
 import torch.nn.functional as F
 
 from tqdm import tqdm
+import pandas as pd
+
 from gluonts.torch import DeepAREstimator
 from gluonts.dataset.repository import get_dataset
 from gluonts.itertools import select
@@ -177,8 +179,31 @@ def main(args):
             distr = model.output_distribution(params)
         # stack the batch_anomalies along the prediction length dimension
         anomalies.append(torch.stack(batch_anomalies, dim=1))
 
-    anomalies = torch.cat(anomalies, dim=0)
+    anomalies = torch.cat(anomalies, dim=0).cpu().numpy()
+
+    # save as csv
+    all_dates = []
+    all_flags = []
+    all_targets = []
+    for i, (entry, flags) in enumerate(zip(dataset.test, anomalies)):
+        start_date = entry["start"].to_timestamp()
+        target = entry["target"]
+        dates = pd.date_range(
+            start=start_date, periods=len(target), freq=dataset.metadata.freq
+        )
+        # take the last prediction_length dates
+        date_index = dates[-dataset.metadata.prediction_length :]
+        target_slice = target[-dataset.metadata.prediction_length :]
+        all_dates.append(date_index)
+        all_flags.append(flags.flatten().astype(bool))
+        all_targets.append(target_slice)
+
+    # create a dataframe with the date_index and the flags
+    anomaly_df = pd.DataFrame(
+        {"date": all_dates, "is_anomaly": all_flags, "target": all_targets}
+    )
+    anomaly_df.set_index("date", inplace=True)
+    anomaly_df.to_csv(f"anomalies_{args.dataset}.csv")
 
 
 if __name__ == "__main__":
@@ -192,7 +217,7 @@ def main(args):
"--context_length", type=int, default=None, help="Context length"
)
parser.add_argument(
"--max_epochs", type=int, default=3, help="Maximum number of epochs"
"--max_epochs", type=int, default=10, help="Maximum number of epochs"
)
parser.add_argument(
"--batch_size", type=int, default=32, help="Batch size"
@@ -202,7 +227,7 @@ def main(args):

     if args.context_length is None:
         args.context_length = (
-            get_dataset(args.dataset).metadata.prediction_length * 4
+            get_dataset(args.dataset).metadata.prediction_length * 10
         )
 
     main(args)
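A note on the saving step: all_dates, all_flags, and all_targets each collect one array per test series, so the DataFrame built in the diff has one row per series whose cells hold whole arrays, and to_csv serializes those arrays as strings. If one row per timestamp is the intended layout, a long-format variant along these lines (a sketch, not part of this commit; the np.concatenate calls and the extra item_id column are assumptions) would flatten the per-series pieces first:

import numpy as np
import pandas as pd

# Hypothetical long-format alternative: one row per (series, timestamp).
# all_dates, all_flags, and all_targets are the per-series lists built in
# the loop above; item_id records which series each row came from.
anomaly_df = pd.DataFrame(
    {
        # repeat each series index once per prediction step
        "item_id": np.repeat(
            np.arange(len(all_flags)), [len(f) for f in all_flags]
        ),
        # concatenate the per-series arrays end to end
        "date": np.concatenate(all_dates),
        "is_anomaly": np.concatenate(all_flags),
        "target": np.concatenate(all_targets),
    }
)
anomaly_df.set_index("date", inplace=True)
anomaly_df.to_csv(f"anomalies_{args.dataset}.csv")

Saved this way, the file round-trips with pd.read_csv(..., index_col="date", parse_dates=True), whereas the per-series layout in the commit writes array-valued cells. Assuming the script's existing --dataset flag, the change can be exercised with, e.g., python examples/anomaly_detection_pytorch.py --dataset electricity.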
