diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index eb708984..33180d8a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macOS-latest] # add windows-2019 when poetry allows installation with `-f` flag + os: [ubuntu-latest, macos-13] # add windows-2019 when poetry allows installation with `-f` flag python-version: ["3.8", "3.9", "3.10"] defaults: run: @@ -62,7 +62,7 @@ jobs: run: poetry run python -m pip install pip -U - name: Install dependencies - run: poetry install -E "github-actions graph mqf2" + run: poetry install -E "github-actions graph" # - name: Install pytorch geometric dependencies # shell: bash diff --git a/docs/requirements.txt b/docs/requirements.txt index 09435203..c68cd39e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,6 +7,7 @@ lightning >=2.0.0 cloudpickle torch >=2.0,!=2.0.1 optuna >=3.1.0 +optuna-integration scipy pandas >=1.3 scikit-learn >1.2 diff --git a/poetry.lock b/poetry.lock index 1648778c..73f8c93f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -302,7 +302,7 @@ files = [ name = "backports-functools-lru-cache" version = "1.6.6" description = "Backport of functools.lru_cache" -optional = true +optional = false python-versions = ">=2.6" files = [ {file = "backports.functools_lru_cache-1.6.6-py2.py3-none-any.whl", hash = "sha256:77e27d0ffbb463904bdd5ef8b44363f6cd5ef503e664b3f599a3bf5843ed37cf"}, @@ -798,7 +798,7 @@ toml = ["tomli"] name = "cpflows" version = "0.1.2" description = "Convex Potential Flows package" -optional = true +optional = false python-versions = "*" files = [ {file = "cpflows-0.1.2.tar.gz", hash = "sha256:a88f5c8f948776d0619c78bf183ab639543ef4cf8e4d91e64e1e45b13a61bbdd"}, @@ -1204,7 +1204,7 @@ tqdm = ["tqdm"] name = "future" version = "0.18.3" description = "Clean single-source support for Python 3 and 2" -optional = true +optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ {file = "future-0.18.3.tar.gz", hash = "sha256:34a17436ed1e96697a86f9de3d15a3b0be01d8bc8de9c1dffd59fb8234ed5307"}, @@ -1391,7 +1391,7 @@ protobuf = ["grpcio-tools (>=1.58.0)"] name = "h5py" version = "3.9.0" description = "Read and write HDF5 files from Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "h5py-3.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eb7bdd5e601dd1739698af383be03f3dad0465fe67184ebd5afca770f50df9d6"}, @@ -2701,6 +2701,26 @@ integration = ["botorch (>=0.4.0)", "catboost (>=0.26)", "catboost (>=0.26,<1.2) optional = ["boto3", "botorch", "cmaes (>=0.10.0)", "google-cloud-storage", "matplotlib (!=3.6.0)", "pandas", "plotly (>=4.9.0)", "redis", "scikit-learn (>=0.24.2)"] test = ["coverage", "fakeredis[lua]", "kaleido", "moto", "pytest", "scipy (>=1.9.2)"] +[[package]] +name = "optuna-integration" +version = "3.6.0" +description = "Integration libraries of Optuna." +optional = false +python-versions = "*" +files = [ + {file = "optuna-integration-3.6.0.tar.gz", hash = "sha256:f261c38586b22cd95639287ca694fc0f788482cfbb7bb83803caf404ce06a55a"}, + {file = "optuna_integration-3.6.0-py3-none-any.whl", hash = "sha256:e281d4902ab728b4c86a997eb01e7bc54d921ae7cff40ed8f4e083f49d37e033"}, +] + +[package.dependencies] +optuna = "*" + +[package.extras] +all = ["botorch (<0.10.0)", "catalyst", "catboost (>=0.26)", "catboost (>=0.26,<1.2)", "cma", "distributed", "fastai", "gpytorch", "lightgbm", "lightning", "mlflow", "mxnet", "pandas", "pytorch-ignite", "scikit-learn (>=0.24.2)", "scikit-optimize", "scipy (>=1.9.2)", "shap", "skorch", "tensorboard", "tensorflow", "torch", "wandb", "xgboost"] +checking = ["black", "blackdoc", "hacking", "isort", "mypy", "types-PyYAML", "types-redis", "types-setuptools", "typing-extensions (>=3.10.0.0)"] +document = ["cma", "mlflow", "pandas", "scikit-learn (>=0.24.2)", "scipy (>=1.9.2)", "sphinx", "sphinx-rtd-theme"] +test = ["coverage", "fakeredis[lua]", "pytest"] + [[package]] name = "packaging" version = "23.1" @@ -4114,7 +4134,7 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo name = "seaborn" version = "0.12.2" description = "Statistical data visualization" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "seaborn-0.12.2-py3-none-any.whl", hash = "sha256:ebf15355a4dba46037dfd65b7350f014ceb1f13c05e814eda2c9f5fd731afc08"}, @@ -4437,35 +4457,35 @@ description = "Statistical computations and models for Python" optional = false python-versions = ">=3.8" files = [ - {file = "statsmodels-0.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16bfe0c96a53b20fa19067e3b6bd2f1d39e30d4891ea0d7bc20734a0ae95942d"}, - {file = "statsmodels-0.14.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5a6a0a1a06ff79be8aa89c8494b33903442859add133f0dda1daf37c3c71682e"}, - {file = "statsmodels-0.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77b3cd3a5268ef966a0a08582c591bd29c09c88b4566c892a7c087935234f285"}, - {file = "statsmodels-0.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c64ebe9cf376cba0c31aed138e15ed179a1d128612dd241cdf299d159e5e882"}, - {file = "statsmodels-0.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:229b2f676b4a45cb62d132a105c9c06ca8a09ffba060abe34935391eb5d9ba87"}, - {file = "statsmodels-0.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb471f757fc45102a87e5d86e87dc2c8c78b34ad4f203679a46520f1d863b9da"}, - {file = "statsmodels-0.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:582f9e41092e342aaa04920d17cc3f97240e3ee198672f194719b5a3d08657d6"}, - {file = "statsmodels-0.14.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7ebe885ccaa64b4bc5ad49ac781c246e7a594b491f08ab4cfd5aa456c363a6f6"}, - {file = "statsmodels-0.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b587ee5d23369a0e881da6e37f78371dce4238cf7638a455db4b633a1a1c62d6"}, - {file = "statsmodels-0.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ef7fa4813c7a73b0d8a0c830250f021c102c71c95e9fe0d6877bcfb56d38b8c"}, - {file = "statsmodels-0.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afe80544ef46730ea1b11cc655da27038bbaa7159dc5af4bc35bbc32982262f2"}, - {file = "statsmodels-0.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:a6ad7b8aadccd4e4dd7f315a07bef1bca41d194eeaf4ec600d20dea02d242fce"}, - {file = "statsmodels-0.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0eea4a0b761aebf0c355b726ac5616b9a8b618bd6e81a96b9f998a61f4fd7484"}, - {file = "statsmodels-0.14.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4c815ce7a699047727c65a7c179bff4031cff9ae90c78ca730cfd5200eb025dd"}, - {file = "statsmodels-0.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575f61337c8e406ae5fa074d34bc6eb77b5a57c544b2d4ee9bc3da6a0a084cf1"}, - {file = "statsmodels-0.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8be53cdeb82f49c4cb0fda6d7eeeb2d67dbd50179b3e1033510e061863720d93"}, - {file = "statsmodels-0.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:6f7d762df4e04d1dde8127d07e91aff230eae643aa7078543e60e83e7d5b40db"}, - {file = "statsmodels-0.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:fc2c7931008a911e3060c77ea8933f63f7367c0f3af04f82db3a04808ad2cd2c"}, - {file = "statsmodels-0.14.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3757542c95247e4ab025291a740efa5da91dc11a05990c033d40fce31c450dc9"}, - {file = "statsmodels-0.14.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:de489e3ed315bdba55c9d1554a2e89faa65d212e365ab81bc323fa52681fc60e"}, - {file = "statsmodels-0.14.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76e290f4718177bffa8823a780f3b882d56dd64ad1c18cfb4bc8b5558f3f5757"}, - {file = "statsmodels-0.14.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71054f9dbcead56def14e3c9db6f66f943110fdfb19713caf0eb0f08c1ec03fd"}, - {file = "statsmodels-0.14.0-cp38-cp38-win_amd64.whl", hash = "sha256:d7fda067837df94e0a614d93d3a38fb6868958d37f7f50afe2a534524f2660cb"}, - {file = "statsmodels-0.14.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1c7724ad573af26139a98393ae64bc318d1b19762b13442d96c7a3e793f495c3"}, - {file = "statsmodels-0.14.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3b0a135f3bfdeec987e36e3b3b4c53e0bb87a8d91464d2fcc4d169d176f46fdb"}, - {file = "statsmodels-0.14.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce28eb1c397dba437ec39b9ab18f2101806f388c7a0cf9cdfd8f09294ad1c799"}, - {file = "statsmodels-0.14.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68b1c768dd94cc5ba8398121a632b673c625491aa7ed627b82cb4c880a25563f"}, - {file = "statsmodels-0.14.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d1e3e10dfbfcd58119ba5a4d3c7d519182b970a2aebaf0b6f539f55ae16058d"}, - {file = "statsmodels-0.14.0.tar.gz", hash = "sha256:6875c7d689e966d948f15eb816ab5616f4928706b180cf470fd5907ab6f647a4"}, + {file = "statsmodels-0.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:43af9c0b07c9d72f275cf14ea54a481a3f20911f0b443181be4769def258fdeb"}, + {file = "statsmodels-0.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16975ab6ad505d837ba9aee11f92a8c5b49c4fa1ff45b60fe23780b19e5705e"}, + {file = "statsmodels-0.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e278fe74da5ed5e06c11a30851eda1af08ef5af6be8507c2c45d2e08f7550dde"}, + {file = "statsmodels-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0564d92cb05b219b4538ed09e77d96658a924a691255e1f7dd23ee338df441b"}, + {file = "statsmodels-0.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5385e22e72159a09c099c4fb975f350a9f3afeb57c1efce273b89dcf1fe44c0f"}, + {file = "statsmodels-0.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:0a8aae75a2e08ebd990e5fa394f8e32738b55785cb70798449a3f4207085e667"}, + {file = "statsmodels-0.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b69a63ad6c979a6e4cde11870ffa727c76a318c225a7e509f031fbbdfb4e416a"}, + {file = "statsmodels-0.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7562cb18a90a114f39fab6f1c25b9c7b39d9cd5f433d0044b430ca9d44a8b52c"}, + {file = "statsmodels-0.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3abaca4b963259a2bf349c7609cfbb0ce64ad5fb3d92d6f08e21453e4890248"}, + {file = "statsmodels-0.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0f727fe697f6406d5f677b67211abe5a55101896abdfacdb3f38410405f6ad8"}, + {file = "statsmodels-0.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b6838ac6bdb286daabb5e91af90fd4258f09d0cec9aace78cc441cb2b17df428"}, + {file = "statsmodels-0.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:709bfcef2dbe66f705b17e56d1021abad02243ee1a5d1efdb90f9bad8b06a329"}, + {file = "statsmodels-0.14.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f32a7cd424cf33304a54daee39d32cccf1d0265e652c920adeaeedff6d576457"}, + {file = "statsmodels-0.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f8c30181c084173d662aaf0531867667be2ff1bee103b84feb64f149f792dbd2"}, + {file = "statsmodels-0.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de2b97413913d52ad6342dece2d653e77f78620013b7705fad291d4e4266ccb"}, + {file = "statsmodels-0.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3420f88289c593ba2bca33619023059c476674c160733bd7d858564787c83d3"}, + {file = "statsmodels-0.14.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c008e16096f24f0514e53907890ccac6589a16ad6c81c218f2ee6752fdada555"}, + {file = "statsmodels-0.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:bc0351d279c4e080f0ce638a3d886d312aa29eade96042e3ba0a73771b1abdfb"}, + {file = "statsmodels-0.14.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bf293ada63b2859d95210165ad1dfcd97bd7b994a5266d6fbeb23659d8f0bf68"}, + {file = "statsmodels-0.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:44ca8cb88fa3d3a4ffaff1fb8eb0e98bbf83fc936fcd9b9eedee258ecc76696a"}, + {file = "statsmodels-0.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d5373d176239993c095b00d06036690a50309a4e00c2da553b65b840f956ae6"}, + {file = "statsmodels-0.14.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a532dfe899f8b6632cd8caa0b089b403415618f51e840d1817a1e4b97e200c73"}, + {file = "statsmodels-0.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:4fe0a60695952b82139ae8750952786a700292f9e0551d572d7685070944487b"}, + {file = "statsmodels-0.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:04293890f153ffe577e60a227bd43babd5f6c1fc50ea56a3ab1862ae85247a95"}, + {file = "statsmodels-0.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3e70a2e93d54d40b2cb6426072acbc04f35501b1ea2569f6786964adde6ca572"}, + {file = "statsmodels-0.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab3a73d16c0569adbba181ebb967e5baaa74935f6d2efe86ac6fc5857449b07d"}, + {file = "statsmodels-0.14.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eefa5bcff335440ee93e28745eab63559a20cd34eea0375c66d96b016de909b3"}, + {file = "statsmodels-0.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:bc43765710099ca6a942b5ffa1bac7668965052542ba793dd072d26c83453572"}, + {file = "statsmodels-0.14.1.tar.gz", hash = "sha256:2260efdc1ef89f39c670a0bd8151b1d0843567781bcafec6cda0534eb47a94f6"}, ] [package.dependencies] @@ -4487,7 +4507,7 @@ docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "n name = "subprocess32" version = "3.5.4" description = "A backport of the subprocess module from Python 3 for use on 2.x." -optional = true +optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, <4" files = [ {file = "subprocess32-3.5.4-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:88e37c1aac5388df41cc8a8456bb49ebffd321a3ad4d70358e3518176de3a56b"}, @@ -4712,7 +4732,7 @@ visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.2.0)"] name = "torchvision" version = "0.17.1" description = "image and video datasets and models for torch deep learning" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "torchvision-0.17.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:06418880212b66e45e855dd39f536e7fd48b4e6b034a11dd9fe9e2384afb51ec"}, @@ -5055,9 +5075,8 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] github-actions = ["pytest-github-actions-annotate-failures"] graph = ["networkx"] -mqf2 = ["cpflows"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "acf0ee98a7ed5f9c84477905279c6eb686064b786d51ba5e5dab7501cb3252f9" +content-hash = "63e6eae67fe328b80bfa092d1ce44663d22a9bf90afc756b98b04232ba376b5b" diff --git a/pyproject.toml b/pyproject.toml index 826fcf26..c0d5064a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,15 +58,18 @@ python = ">=3.8,<3.11" torch = "^2.0.0,!=2.0.1" lightning = "^2.0.0" optuna = "^3.1.0" +optuna-integration="*" scipy = "^1.8" pandas = ">=1.3.0,<=3.0.0" +pyarrow = "*" scikit-learn = "^1.2" matplotlib = "*" +tensorboard = "^2.12.1" +cpflows = "^0.1.2" statsmodels = "*" pytest-github-actions-annotate-failures = { version = "*", optional = true } networkx = { version = "^3.0.0", optional = true } -cpflows = { version = "^0.1.2", optional = true } fastapi = ">=0.80" pytorch-optimizer = "^2.5.1" @@ -108,7 +111,6 @@ pandoc = "^2.3" [tool.poetry.extras] # extras github-actions = ["pytest-github-actions-annotate-failures"] graph = ["networkx"] -mqf2 = ["cpflows"] [tool.poetry-dynamic-versioning] enable = true diff --git a/pytorch_forecasting/data/examples.py b/pytorch_forecasting/data/examples.py index 80409170..03ae566a 100644 --- a/pytorch_forecasting/data/examples.py +++ b/pytorch_forecasting/data/examples.py @@ -3,10 +3,10 @@ """ from pathlib import Path +from urllib.request import urlretrieve import numpy as np import pandas as pd -import requests BASE_URL = "https://github.com/jdb78/pytorch-forecasting/raw/master/examples/data/" @@ -28,9 +28,7 @@ def _get_data_by_filename(fname: str) -> Path: # check if file exists - download if necessary if not full_fname.exists(): url = BASE_URL + fname - download = requests.get(url, allow_redirects=True) - with open(full_fname, "wb") as file: - file.write(download.content) + urlretrieve(url, full_fname) return full_fname diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 7ee1524e..2f7e6e5b 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -18,7 +18,7 @@ from lightning.pytorch.utilities.parsing import AttributeDict, get_init_args import matplotlib.pyplot as plt import numpy as np -from numpy.lib.function_base import iterable +from numpy import iterable import pandas as pd import pytorch_optimizer from pytorch_optimizer import Ranger21