diff --git a/.gitignore b/.gitignore
index e1da78f..2f83499 100755
--- a/.gitignore
+++ b/.gitignore
@@ -92,7 +92,7 @@ instance/
 .scrapy
 
 # Sphinx documentation
-docs/build/
+docs/_build/
 docs/modules.rst
 docs/atm.rst
 docs/atm.*.rst
@@ -136,3 +136,6 @@ ENV/
 
 # mypy
 .mypy_cache/
+
+# pid
+*.pid
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..c0cc0b3
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,27 @@
+# Config file for automatic testing at travis-ci.org
+language: python
+python:
+  - 2.7
+  - 3.5
+  - 3.6
+
+# Command to install dependencies
+install:
+  - pip install -U tox-travis codecov
+
+# Command to run tests
+script: tox
+
+after_success: codecov
+
+deploy:
+
+  - provider: pages
+    skip-cleanup: true
+    github-token: "$GITHUB_TOKEN"
+    keep-history: true
+    local-dir: docs/_build/html
+    target-branch: gh-pages
+    on:
+      branch: master
+      python: 3.6
diff --git a/API.md b/API.md
new file mode 100644
index 0000000..a2a8793
--- /dev/null
+++ b/API.md
@@ -0,0 +1,425 @@
+# REST API
+
+**ATM** comes with the possibility to start a server process that enables interacting with
+it via a REST API server that runs over [flask](http://flask.pocoo.org/).
+
+In this document you will find a brief explanation of how to start and use it.
+
+
+## Quickstart
+
+In this section we will briefly show the basic usage of the REST API.
+
+For more detailed information about all the operations supported by the API, please point your
+browser to http://127.0.0.1:5000/ and explore the examples provided by the
+[Swagger](https://swagger.io/) interface.
+
+
+### 1. Start the REST API Server
+
+In order to start a REST API server, after installing ATM open a terminal, activate its
+virtualenv, and execute this command:
+
+```bash
+atm start
+```
+
+This will start the **ATM** server as a background service. The REST server will be listening on
+port 5000 of your machine, and if you point your browser at http://127.0.0.1:5000/, you will see
+the documentation website that shows information about all the REST operations allowed by the API.
+
+Optionally, the `--port <port>` argument can be added to modify the port which the server listens on:
+
+```bash
+atm start --port 1234
+```
+
+If you would like to see the status of the server process you can run:
+
+```bash
+atm status
+```
+
+An output similar to this one will appear:
+
+```bash
+ATM is running with 1 worker
+ATM REST server is listening on http://127.0.0.1:5000
+```
+
+In order to stop the server you can run the following command:
+
+```bash
+atm stop
+```
+
+Notice that `atm start` will start one worker by default. If you would like to launch more than one,
+you can do so by adding the argument `--workers <number>`:
+
+```bash
+atm start --workers 4
+```
+
+For more detailed options you can run `atm start --help` to obtain a list of the accepted arguments.
+
+### 2. Create a Dataset
+
+Once the server is running, you can register your first dataset using the API. To do so, you need
+to send the path to your `CSV` file and the name of your `target_column` in a `POST` request to
+`/api/datasets`.
+
+This call will create a simple `dataset` in our database:
+
+```bash
+POST /api/datasets HTTP/1.1
+Content-Type: application/json
+
+{
+    "class_column": "your_target_column",
+    "train_path": "path/to/your.csv"
+}
+```
+
+Once you have created some datasets, you can see them by sending a `GET` request:
+
+```bash
+GET /api/datasets HTTP/1.1
+```
+
+This will return a `json` with all the information about the stored datasets.
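If you prefer to call the API from Python instead of sending raw HTTP requests, a minimal sketch using the `requests` library could look like this. It assumes the server is listening on the default `http://127.0.0.1:5000`, and the payload values (`your_target_column`, `path/to/your.csv`) are placeholders to replace with your own data:

```python
import requests

BASE_URL = 'http://127.0.0.1:5000'  # default host and port used by `atm start`

# Register a dataset (same payload as the POST example above)
payload = {
    'class_column': 'your_target_column',
    'train_path': 'path/to/your.csv',
}
response = requests.post(BASE_URL + '/api/datasets', json=payload)
response.raise_for_status()
print(response.json())

# List all the registered datasets
print(requests.get(BASE_URL + '/api/datasets').json())
```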
+
+As an example, you can get and register a demo dataset by running the following two commands:
+
+```bash
+atm get_demos
+curl -v localhost:5000/api/datasets -H'Content-Type: application/json' \
+-d'{"class_column": "class", "train_path": "demos/pollution_1.csv"}'
+```
+
+### 3. Trigger a Datarun
+
+In order to trigger a datarun once you have created a dataset, you have to send its `dataset_id`
+in a `POST` request to `/api/run`. This will trigger the `workers` with the default values.
+
+```bash
+POST /api/run HTTP/1.1
+Content-type: application/json
+
+{
+    "dataset_id": id_of_your_dataset
+}
+```
+
+If you have followed the above example and created a `pollution` dataset in the database, you can
+run the following `POST` to trigger its datarun:
+
+```bash
+curl -v localhost:5000/api/run -H'Content-type: application/json' -d'{"dataset_id": 1}'
+```
+
+**NOTE**: at least one worker must be running in order to process the datarun.
+
+While running, the workers will log what they are doing in the file `atm.log`.
+
+In order to monitor their activity in real time, you can execute this on another terminal:
+
+```bash
+tail -f atm.log
+```
+
+### 4. Browse the results
+
+Once the database is populated, you can use the REST API to explore the following 4 models:
+
+* Datasets
+* Dataruns
+* Hyperpartitions
+* Classifiers
+
+And these are the operations that can be performed on them:
+
+#### Get all objects from a model
+
+In order to get all the objects for a single model, you need to make a `GET` request to
+`/api/<model>`.
+
+The output will be a JSON with 4 entries:
+
+* `num_results`: The number of results found
+* `objects`: A list containing a subdocument for each result
+* `page`: The current page
+* `total_pages`: The number of pages
+
+For example, you can get all the datasets using:
+
+```
+GET /api/datasets HTTP/1.1
+```
+
+And the output will be:
+
+```
+{
+    "num_results": 1,
+    "objects": [
+        {
+            "class_column": "class",
+            "d_features": 16,
+            "dataruns": [
+                {
+                    "budget": 100,
+                    "budget_type": "classifier",
+                    "dataset_id": 1,
+                    "deadline": null,
+                    "description": "uniform__uniform",
+                    "end_time": "2019-04-11T20:58:11.346733",
+                    "gridding": 0,
+                    "id": 1,
+                    "k_window": 3,
+                    "metric": "f1",
+                    "priority": 1,
+                    "r_minimum": 2,
+                    "score_target": "cv_judgment_metric",
+                    "selector": "uniform",
+                    "start_time": "2019-04-11T20:58:02.514514",
+                    "status": "complete",
+                    "tuner": "uniform"
+                }
+            ],
+            "description": null,
+            "id": 1,
+            "k_classes": 2,
+            "majority": 0.516666667,
+            "n_examples": 60,
+            "name": "pollution_1",
+            "size_kb": 8,
+            "test_path": null,
+            "train_path": "/path/to/atm/data/test/pollution_1.csv"
+        }
+    ],
+    "page": 1,
+    "total_pages": 1
+}
+```
+
+#### Get a single object by id
+
+In order to get one particular object of a model, you need to make a `GET` request to
+`/api/<model>/<id>`.
+
+The output will be the document representing the corresponding object.
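The raw HTTP example below retrieves the dataset with id 1. The same request made from Python, reusing the `requests`-based sketch shown earlier (and again assuming the default host and port), might look like this; the fields accessed are the ones shown in the example outputs of this document:

```python
import requests

BASE_URL = 'http://127.0.0.1:5000'

# Fetch the dataset with id 1 and inspect its dataruns
dataset = requests.get(BASE_URL + '/api/datasets/1').json()
print(dataset['name'], dataset['class_column'])
for datarun in dataset['dataruns']:
    print(datarun['id'], datarun['status'])
```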
+ +For example, you can get the dataset with id 1 using: + +``` +GET /api/datasets/1 HTTP/1.1 +``` + +And the output will be: + +``` +{ + "class_column": "class", + "d_features": 16, + "dataruns": [ + { + "budget": 100, + "budget_type": "classifier", + "dataset_id": 1, + "deadline": null, + "description": "uniform__uniform", + "end_time": "2019-04-11T20:58:11.346733", + "gridding": 0, + "id": 1, + "k_window": 3, + "metric": "f1", + "priority": 1, + "r_minimum": 2, + "score_target": "cv_judgment_metric", + "selector": "uniform", + "start_time": "2019-04-11T20:58:02.514514", + "status": "complete", + "tuner": "uniform" + } + ], + "description": null, + "id": 1, + "k_classes": 2, + "majority": 0.516666667, + "n_examples": 60, + "name": "pollution_1", + "size_kb": 8, + "test_path": null, + "train_path": "/path/to/atm/data/test/pollution_1.csv" +} +``` + +#### Get all the children objects + +In order to get all the childre objects from one parent object, you need to make a +`GET` request to `/api///`. + +The output will be in the same format as if you had requested all the elements from the +children model, but with the results filtered by the parent one. + +So, for example, in order to get all the dataruns that use the dataset with id 1, you can use: + +``` +GET /api/datasets/1/dataruns HTTP/1.1 +``` + +And the output will be (note that some parts have been cut): + +``` +{ + "num_results": 1, + "objects": [ + { + "budget": 100, + "budget_type": "classifier", + "classifiers": [ + { + "cv_judgment_metric": 0.8444444444, + "cv_judgment_metric_stdev": 0.1507184441, + "datarun_id": 1, + "end_time": "2019-04-11T20:58:02.600185", + "error_message": null, + "host": "83.56.245.36", + "hyperparameter_values_64": "gAN9cQAoWAsAAABuX25laWdoYm9yc3EBY251bXB5LmNvcmUubXVsdGlhcnJheQpzY2FsYXIKcQJjbnVtcHkKZHR5cGUKcQNYAgAAAGk4cQRLAEsBh3EFUnEGKEsDWAEAAAA8cQdOTk5K/////0r/////SwB0cQhiQwgPAAAAAAAAAHEJhnEKUnELWAkAAABsZWFmX3NpemVxDGgCaAZDCCsAAAAAAAAAcQ2GcQ5ScQ9YBwAAAHdlaWdodHNxEFgIAAAAZGlzdGFuY2VxEVgJAAAAYWxnb3JpdGhtcRJYCQAAAGJhbGxfdHJlZXETWAYAAABtZXRyaWNxFFgJAAAAbWFuaGF0dGFucRVYBgAAAF9zY2FsZXEWiHUu", + "hyperpartition_id": 23, + "id": 1, + "metrics_location": "metrics/pollution_1-4bc39b14.metric", + "model_location": "models/pollution_1-4bc39b14.model", + "start_time": "2019-04-11T20:58:02.539046", + "status": "complete", + "test_judgment_metric": 0.6250000000 + }, + ... 
+ ], + "dataset": { + "class_column": "class", + "d_features": 16, + "description": null, + "id": 1, + "k_classes": 2, + "majority": 0.516666667, + "n_examples": 60, + "name": "pollution_1", + "size_kb": 8, + "test_path": null, + "train_path": "/path/to/atm/data/test/pollution_1.csv" + }, + "dataset_id": 1, + "deadline": null, + "description": "uniform__uniform", + "end_time": "2019-04-11T20:58:11.346733", + "gridding": 0, + "hyperpartitions": [ + { + "categorical_hyperparameters_64": "gANdcQAoWAcAAABwZW5hbHR5cQFYAgAAAGwxcQKGcQNYDQAAAGZpdF9pbnRlcmNlcHRxBIiGcQVlLg==", + "constant_hyperparameters_64": "gANdcQAoWAwAAABjbGFzc193ZWlnaHRxAVgIAAAAYmFsYW5jZWRxAoZxA1gGAAAAX3NjYWxlcQSIhnEFZS4=", + "datarun_id": 1, + "id": 1, + "method": "logreg", + "status": "incomplete", + "tunable_hyperparameters_64": "gANdcQAoWAEAAABDcQFjYnRiLmh5cGVyX3BhcmFtZXRlcgpGbG9hdEV4cEh5cGVyUGFyYW1ldGVyCnECY2J0Yi5oeXBlcl9wYXJhbWV0ZXIKUGFyYW1UeXBlcwpxA0sFhXEEUnEFXXEGKEc+5Pi1iONo8UdA+GoAAAAAAGWGcQeBcQh9cQkoWAwAAABfcGFyYW1fcmFuZ2VxCmgGWAUAAAByYW5nZXELXXEMKEfAFAAAAAAAAEdAFAAAAAAAAGV1YoZxDVgDAAAAdG9scQ5oAmgFXXEPKEc+5Pi1iONo8UdA+GoAAAAAAGWGcRCBcRF9cRIoaApoD2gLXXETKEfAFAAAAAAAAEdAFAAAAAAAAGV1YoZxFGUu" + }, + ... + ], + "id": 1, + "k_window": 3, + "metric": "f1", + "priority": 1, + "r_minimum": 2, + "score_target": "cv_judgment_metric", + "selector": "uniform", + "start_time": "2019-04-11T20:58:02.514514", + "status": "complete", + "tuner": "uniform" + } + ], + "page": 1, + "total_pages": 1 +} +``` + +## Additional information + +### Start additional process with different pid file + +If you would like to run more workers or you would like to launch a second **ATM** process, you can +do so by specifying a different `PID` file. + +For example: + +```bash +atm start --no-server -w 4 --pid additional_workers.pid +``` + +To check the status of this process we have to run: + +```bash +atm status --pid additional_workers.pid +``` + +This will print an output like this: + +```bash +ATM is running with 4 workers +``` + +### Restart the ATM process + +If you have an **ATM** process running and you would like to restart it and add more workers to it +or maybe change the port on which is running, you can achieve so with the `atm restart`: + +```bash +atm restart +``` + +This command will restart the server with the default values, so if you would like to use other +options you can run `--help` to see the accepted arguments: + +```bash +atm restart --help +``` + +### Stop the ATM process + +As we saw before, by runing the command `atm stop` you will `terminate` the ATM process. However +this command accepts a few arguments in order to control this behaviour: + +* `-t TIMEOUT`, `--timeout TIMEOUT`, time to wait in order to check if the process has been +terminated. + +* `-f`, `--force`, Kill the process if it does not terminate gracefully. +* `--pid PIDFILE`, PID file to use + +### Start the ATM REST API server in foreground + +If you would like to monitorize the server for debugging process, you can do so by runing the +with the following command: + +```bash +atm server +``` + +An output similar to this one should apear in the terminal: + +```bash + * Serving Flask app "api.setup" (lazy loading) + * Environment: production + WARNING: Do not use the development server in a production environment. + Use a production WSGI server instead. + * Debug mode: on + * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit) + * Restarting with stat + * Debugger is active! 
+ * Debugger PIN: 150-127-826 +``` + +For additional arguments run `atm server --help` + +**Note** that this command will not launch any `workers` process. In order to launch a foreground +worker you have to do so by runing `atm worker`. diff --git a/HISTORY.md b/HISTORY.md index 7e26a33..3aed563 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,19 @@ # History +## 0.1.2 (2019-05-07) + +REST API and Cluster Management. + +### New Features + +* REST API Server - Issues [#82](https://github.com/HDI-Project/ATM/issues/82) and + [#132](https://github.com/HDI-Project/ATM/issues/132) by @RogerTangos, @pvk-developer and @csala +* Add Cluster Management commands to start and stop the server and multiple workers + as background processes - Issue [#130](https://github.com/HDI-Project/ATM/issues/130) by + @pvk-developer and @csala +* Add TravisCI and migrate docs to GitHub Pages - Issue [#129](https://github.com/HDI-Project/ATM/issues/129) + by @pvk-developer + ## 0.1.1 (2019-04-02) First Release on PyPi. diff --git a/Makefile b/Makefile index 94a2957..9f972ef 100644 --- a/Makefile +++ b/Makefile @@ -48,9 +48,9 @@ clean-pyc: ## remove Python file artifacts .PHONY: clean-docs clean-docs: ## remove previously built docs - rm -rf docs/build - rm -f docs/atm.rst - rm -f docs/atm.*.rst + rm -rf docs/_build + rm -f docs/api/atm.rst + rm -f docs/api/atm.*.rst rm -f docs/modules.rst $(MAKE) -C docs clean @@ -106,7 +106,7 @@ fix-lint: ## fix lint issues using autoflake, autopep8, and isort .PHONY: test test: ## run tests quickly with the default Python - python -m pytest tests + python -m pytest --cov=atm .PHONY: test-all test-all: ## run tests on every Python version with tox diff --git a/README.md b/README.md index 18464c1..b14c472 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ [![CircleCI][circleci-img]][circleci-url] +[![Travis][travis-img]][travis-url] [![Coverage status][codecov-img]][codecov-url] -[![Documentation][rtd-img]][rtd-url] [circleci-img]: https://circleci.com/gh/HDI-Project/ATM.svg?style=shield [circleci-url]: https://circleci.com/gh/HDI-Project/ATM @@ -17,14 +17,12 @@ [pypi-url]: https://pypi.python.org/pypi/atm [codecov-img]: https://codecov.io/gh/HDI-project/ATM/branch/master/graph/badge.svg [codecov-url]: https://codecov.io/gh/HDI-project/ATM -[rtd-img]: https://readthedocs.org/projects/atm/badge/?version=latest -[rtd-url]: http://atm.readthedocs.io/en/latest/ # ATM - Auto Tune Models - Free software: MIT license -- Documentation: http://atm.readthedocs.io/en/latest/ +- Documentation: https://hdi-project.github.io/ATM/ ATM is an open source software library under the @@ -311,6 +309,15 @@ database for its datarun! --aws-config config/aws.yaml \ ``` + +## REST API Server + +**ATM** comes with the possibility to start a server process that enables interacting with +the ModelHub Database via a REST API server that runs over [flask](http://flask.pocoo.org/). + +For more details about how to start and use this REST API please check the [API.md](API.md) document. 
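As a quick illustration, once the server has been started with `atm start` and a dataset has been registered (see [API.md](API.md)), a datarun can be triggered from Python with a minimal sketch like the one below. It assumes the server is listening on the default port 5000 and that a dataset with id 1 already exists:

```python
import requests

# Assumes the REST server was started with `atm start` on the default port
# and that a dataset with id 1 has already been registered (see API.md).
response = requests.post('http://127.0.0.1:5000/api/run', json={'dataset_id': 1})
print(response.json())  # e.g. {'status': 200, 'datarun_ids': [...]}
```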
+ + diff --git a/atm/__init__.py b/atm/__init__.py index 6d9914a..1e5023a 100644 --- a/atm/__init__.py +++ b/atm/__init__.py @@ -12,7 +12,7 @@ __author__ = """MIT Data To AI Lab""" __email__ = 'dailabmit@gmail.com' -__version__ = '0.1.1' +__version__ = '0.1.2-dev' # this defines which modules will be imported by "from atm import *" __all__ = ['config', 'classifier', 'constants', 'database', 'enter_data', diff --git a/atm/api/__init__.py b/atm/api/__init__.py new file mode 100644 index 0000000..a838085 --- /dev/null +++ b/atm/api/__init__.py @@ -0,0 +1,45 @@ +from flask import Flask, jsonify, redirect, request +from flask_restless_swagger import SwagAPIManager as APIManager +from flask_sqlalchemy import SQLAlchemy + +from atm.api.preprocessors import DATASET_PREPROCESSORS +from atm.api.utils import auto_abort, make_absolute +from atm.config import RunConfig + + +def create_app(atm, debug=False): + db = atm.db + app = Flask(__name__) + app.config['DEBUG'] = debug + app.config['SQLALCHEMY_DATABASE_URI'] = make_absolute(db.engine.url) + + # Create the Flask-Restless API manager. + manager = APIManager(app, flask_sqlalchemy_db=SQLAlchemy(app)) + + @app.route('/api/run', methods=['POST']) + @auto_abort((KeyError, ValueError)) + def atm_run(): + data = request.json + run_conf = RunConfig(data) + + dataruns = atm.create_dataruns(run_conf) + + response = { + 'status': 200, + 'datarun_ids': [datarun.id for datarun in dataruns] + } + + return jsonify(response) + + @app.route('/') + def swagger(): + return redirect('/static/swagger/swagger-ui/index.html') + + # Create API endpoints, which will be available at /api/ by + # default. Allowed HTTP methods can be specified as well. + manager.create_api(db.Dataset, methods=['GET', 'POST'], preprocessors=DATASET_PREPROCESSORS) + manager.create_api(db.Datarun, methods=['GET']) + manager.create_api(db.Hyperpartition, methods=['GET']) + manager.create_api(db.Classifier, methods=['GET']) + + return app diff --git a/atm/api/preprocessors.py b/atm/api/preprocessors.py new file mode 100644 index 0000000..3937aa4 --- /dev/null +++ b/atm/api/preprocessors.py @@ -0,0 +1,29 @@ +import os + +from atm.api.utils import auto_abort +from atm.encoder import MetaData + + +@auto_abort((KeyError, FileNotFoundError)) +def dataset_post(data): + """Preprocess the Dataset POST data.""" + + train_path = data['train_path'] + name = data.setdefault('name', os.path.basename(train_path)) + data.setdefault('description', name) + meta = MetaData( + data['class_column'], + train_path, + data.get('test_path') + ) + + data['n_examples'] = meta.n_examples + data['k_classes'] = meta.k_classes + data['d_features'] = meta.d_features + data['majority'] = meta.majority + data['size_kb'] = meta.size + + +DATASET_PREPROCESSORS = { + 'POST': [dataset_post] +} diff --git a/atm/api/utils.py b/atm/api/utils.py new file mode 100644 index 0000000..b809a03 --- /dev/null +++ b/atm/api/utils.py @@ -0,0 +1,43 @@ +import logging +import os +import traceback + +import flask + +LOGGER = logging.getLogger(__name__) + + +def make_absolute(url): + if str(url).startswith('sqlite:///'): + url = 'sqlite:///' + os.path.abspath(url.database) + + return url + + +def abort(code, message=None, error=None): + if error is not None: + error = traceback.format_exception_only(type(error), error)[0].strip() + + response = flask.jsonify({ + 'status': code, + 'error': error, + 'message': message + }) + response.status_code = code + flask.abort(response) + + +def auto_abort(exceptions): + def outer(function): + def inner(*args, 
**kwargs): + try: + return function(*args, **kwargs) + except exceptions as ex: + abort(400, error=ex) + except Exception as ex: + LOGGER.exception('Uncontrolled Exception Caught') + abort(500, error=ex) + + return inner + + return outer diff --git a/atm/classifier.py b/atm/classifier.py index cc25af3..1bee3f9 100644 --- a/atm/classifier.py +++ b/atm/classifier.py @@ -14,7 +14,6 @@ import numpy as np import pandas as pd -from past.utils import old_div from sklearn import decomposition from sklearn.gaussian_process.kernels import ( RBF, ConstantKernel, ExpSineSquared, Matern, RationalQuadratic) @@ -159,7 +158,7 @@ def test_final_model(self, X, y): # time the prediction start_time = time.time() total = time.time() - start_time - self.avg_predict_time = old_div(total, float(len(y))) + self.avg_predict_time = total / float(len(y)) # TODO: this is hacky. See https://github.com/HDI-Project/ATM/issues/48 binary = self.num_classes == 2 diff --git a/atm/cli.py b/atm/cli.py index 76e739b..1d19e3e 100644 --- a/atm/cli.py +++ b/atm/cli.py @@ -2,107 +2,379 @@ import argparse import glob +import logging +import multiprocessing import os import shutil +import time -from atm.config import ( - add_arguments_aws_s3, add_arguments_datarun, add_arguments_logging, add_arguments_sql) -from atm.models import ATM +import psutil +from daemon import DaemonContext +from lockfile.pidlockfile import PIDLockFile +from atm.api import create_app +from atm.config import AWSConfig, DatasetConfig, LogConfig, RunConfig, SQLConfig +from atm.core import ATM -def _end_to_end_test(args): - """End to end test""" +LOGGER = logging.getLogger(__name__) -def _work(args): - atm = ATM(**vars(args)) +def _get_atm(args): + sql_conf = SQLConfig(args) + aws_conf = AWSConfig(args) + log_conf = LogConfig(args) + return ATM(sql_conf, aws_conf, log_conf) + + +def _work(args, wait=False): + """Creates a single worker.""" + atm = _get_atm(args) + atm.work( - datarun_ids=args.dataruns, - choose_randomly=args.choose_randomly, + datarun_ids=getattr(args, 'dataruns', None), + choose_randomly=False, save_files=args.save_files, cloud_mode=args.cloud_mode, - total_time=args.time, - wait=False + total_time=getattr(args, 'total_time', None), + wait=wait ) +def _serve(args): + """Launch the ATM API with the given host / port.""" + atm = _get_atm(args) + app = create_app(atm, getattr(args, 'debug', False)) + app.run(host=args.host, port=args.port) + + +def _get_pid_path(pid): + """Returns abspath of the pid file which is stored on the cwd.""" + pid_path = pid + + if not os.path.isabs(pid_path): + pid_path = os.path.join(os.getcwd(), pid_path) + + return pid_path + + +def _get_atm_process(pid_path): + """Return `psutil.Process` of the `pid` file. 
If the pidfile is stale it will release it.""" + pid_file = PIDLockFile(pid_path, timeout=1.0) + + if pid_file.is_locked(): + pid = pid_file.read_pid() + + try: + process = psutil.Process(pid) + if process.name() == 'atm': + return process + else: + pid_file.break_lock() + + except psutil.NoSuchProcess: + pid_file.break_lock() + + +def _status(args): + """Check if the current ATM process is runing.""" + + pid_path = _get_pid_path(args.pid) + process = _get_atm_process(pid_path) + + if process: + workers = 0 + addr = None + for child in process.children(): + connections = child.connections() + if connections: + connection = connections[0] + addr = connection.laddr + + else: + workers += 1 + + s = 's' if workers > 1 else '' + print('ATM is running with {} worker{}'.format(workers, s)) + + if addr: + print('ATM REST server is listening on http://{}:{}'.format(addr.ip, addr.port)) + + else: + print('ATM is not runing.') + + +def _start_background(args): + """Launches the server/worker in daemon processes.""" + if args.server: + LOGGER.info('Starting the REST API server') + + process = multiprocessing.Process(target=_serve, args=(args, )) + process.daemon = True + + process.start() + + pool = multiprocessing.Pool(args.workers) + for _ in range(args.workers): + LOGGER.info('Starting background worker') + pool.apply_async(_work, args=(args, True)) + + pool.close() + pool.join() + + +def _start(args): + """Create a new process of ATM pointing the process to a certain `pid` file.""" + pid_path = _get_pid_path(args.pid) + process = _get_atm_process(pid_path) + + if process: + print('ATM is already running!') + + else: + print('Starting ATM') + + if args.foreground: + _start_background(args) + + else: + pidfile = PIDLockFile(pid_path, timeout=1.0) + + with DaemonContext(pidfile=pidfile, working_directory=os.getcwd()): + # Set up default log file if not already set + if not args.logfile: + _logging_setup(args.verbose, 'atm.log') + + _start_background(args) + + +def _stop(args): + """Stop the current running process of ATM.""" + pid_path = _get_pid_path(args.pid) + process = _get_atm_process(pid_path) + + if process: + process.terminate() + + for _ in range(args.timeout): + if process.is_running(): + time.sleep(1) + else: + break + + if process.is_running(): + print('ATM was not able to stop after {} seconds.'.format(args.timeout)) + if args.force: + print('Killing it.') + process.kill() + + else: + print('Use --force to kill it.') + + else: + print('ATM stopped correctly.') + + else: + print('ATM is not running.') + + +def _restart(args): + _stop(args) + time.sleep(1) + + pid_path = _get_pid_path(args.pid) + process = _get_atm_process(pid_path) + + if process: + print('ATM did not stop correctly. 
Aborting') + else: + _start(args) + + def _enter_data(args): - atm = ATM(**vars(args)) - atm.enter_data() + atm = _get_atm(args) + run_conf = RunConfig(args) + dataset_conf = DatasetConfig(args) + atm.enter_data(dataset_conf, run_conf) -def _make_config(args): - config_templates = os.path.join('config', 'templates') - config_dir = os.path.join(os.path.dirname(__file__), config_templates) - target_dir = os.path.join(os.getcwd(), config_templates) +def _copy_files(pattern, source, target=None): + if isinstance(source, (list, tuple)): + source = os.path.join(*source) + + if target is None: + target = source + + source_dir = os.path.join(os.path.dirname(__file__), source) + target_dir = os.path.join(os.getcwd(), target) + if not os.path.exists(target_dir): os.makedirs(target_dir) - for template in glob.glob(os.path.join(config_dir, '*.yaml')): - target_file = os.path.join(target_dir, os.path.basename(template)) + for source_file in glob.glob(os.path.join(source_dir, pattern)): + target_file = os.path.join(target_dir, os.path.basename(source_file)) print('Generating file {}'.format(target_file)) - shutil.copy(template, target_file) + shutil.copy(source_file, target_file) + + +def _make_config(args): + _copy_files('*.yaml', ('config', 'templates')) -# load other functions from config.py -def _add_common_arguments(parser): - add_arguments_sql(parser) - add_arguments_aws_s3(parser) - add_arguments_logging(parser) +def _get_demos(args): + _copy_files('*.csv', ('data', 'test'), 'demos') def _get_parser(): - parent = argparse.ArgumentParser(add_help=False) + logging_args = argparse.ArgumentParser(add_help=False) + logging_args.add_argument('-v', '--verbose', action='count', default=0) + logging_args.add_argument('-l', '--logfile') - parser = argparse.ArgumentParser(description='ATM Command Line Interface') + parser = argparse.ArgumentParser(description='ATM Command Line Interface', + parents=[logging_args]) subparsers = parser.add_subparsers(title='action', help='Action to perform') parser.set_defaults(action=None) + # Common Arguments + sql_args = SQLConfig.get_parser() + aws_args = AWSConfig.get_parser() + log_args = LogConfig.get_parser() + run_args = RunConfig.get_parser() + dataset_args = DatasetConfig.get_parser() + # Enter Data Parser - enter_data = subparsers.add_parser('enter_data', parents=[parent]) + enter_data_parents = [ + logging_args, + sql_args, + aws_args, + dataset_args, + log_args, + run_args + ] + enter_data = subparsers.add_parser('enter_data', parents=enter_data_parents, + help='Add a Dataset and trigger a Datarun on it.') enter_data.set_defaults(action=_enter_data) - _add_common_arguments(enter_data) - add_arguments_datarun(enter_data) - enter_data.add_argument('--run-per-partition', default=False, action='store_true', - help='if set, generate a new datarun for each hyperpartition') + + # Wroker Args + worker_args = argparse.ArgumentParser(add_help=False) + worker_args.add_argument('--cloud-mode', action='store_true', default=False, + help='Whether to run this worker in cloud mode') + worker_args.add_argument('--no-save', dest='save_files', action='store_false', + help="don't save models and metrics at all") # Worker - worker = subparsers.add_parser('worker', parents=[parent]) + worker_parents = [ + logging_args, + worker_args, + sql_args, + aws_args, + log_args + ] + worker = subparsers.add_parser('worker', parents=worker_parents, + help='Start a single worker in foreground.') worker.set_defaults(action=_work) - _add_common_arguments(worker) - 
worker.add_argument('--cloud-mode', action='store_true', default=False, - help='Whether to run this worker in cloud mode') - worker.add_argument('--dataruns', help='Only train on dataruns with these ids', nargs='+') - worker.add_argument('--time', help='Number of seconds to run worker', type=int) - worker.add_argument('--choose-randomly', action='store_true', - help='Choose dataruns to work on randomly (default = sequential order)') + worker.add_argument('--total-time', help='Number of seconds to run worker', type=int) + + # Server Args + server_args = argparse.ArgumentParser(add_help=False) + server_args.add_argument('--host', help='IP to listen at') + server_args.add_argument('--port', help='Port to listen at', type=int) + + # Server + server = subparsers.add_parser('server', parents=[logging_args, server_args, sql_args], + help='Start the REST API Server in foreground.') + server.set_defaults(action=_serve) + server.add_argument('--debug', help='Start in debug mode', action='store_true') + # add_arguments_sql(server) + + # Background Args + background_args = argparse.ArgumentParser(add_help=False) + background_args.add_argument('--pid', help='PID file to use.', default='atm.pid') + + # Start Args + start_args = argparse.ArgumentParser(add_help=False) + start_args.add_argument('--foreground', action='store_true', help='Run on foreground') + start_args.add_argument('-w', '--workers', default=1, type=int, help='Number of workers') + start_args.add_argument('--no-server', dest='server', action='store_false', + help='Do not start the REST server') + + # Start + start_parents = [ + logging_args, + worker_args, + server_args, + background_args, + start_args, + sql_args, + aws_args, + log_args + ] + start = subparsers.add_parser('start', parents=start_parents, + help='Start an ATM Local Cluster.') + start.set_defaults(action=_start) - worker.add_argument('--no-save', dest='save_files', default=True, - action='store_const', const=False, - help="don't save models and metrics at all") + # Status + status = subparsers.add_parser('status', parents=[logging_args, background_args]) + status.set_defaults(action=_status) + + # Stop Args + stop_args = argparse.ArgumentParser(add_help=False) + stop_args.add_argument('-t', '--timeout', default=5, type=int, + help='Seconds to wait before killing the process.') + stop_args.add_argument('-f', '--force', action='store_true', + help='Kill the process if it does not terminate gracefully.') + + # Stop + stop = subparsers.add_parser('stop', parents=[logging_args, stop_args, background_args], + help='Stop an ATM Local Cluster.') + stop.set_defaults(action=_stop) + + # restart + restart = subparsers.add_parser('restart', parents=start_parents + [stop_args], + help='Restart an ATM Local Cluster.') + restart.set_defaults(action=_restart) # Make Config - make_config = subparsers.add_parser('make_config', parents=[parent]) + make_config = subparsers.add_parser('make_config', parents=[logging_args], + help='Generate a config templates folder in the cwd.') make_config.set_defaults(action=_make_config) - # End to end test - end_to_end = subparsers.add_parser('end_to_end', parents=[parent]) - end_to_end.set_defaults(action=_end_to_end_test) - end_to_end.add_argument('--processes', help='number of processes to run concurrently', - type=int, default=4) - - end_to_end.add_argument('--total-time', help='Total time for each worker to work in seconds.', - type=int, default=None) + # Get Demos + get_demos = subparsers.add_parser('get_demos', parents=[logging_args], + 
help='Generate a demos folder with demo CSVs in the cwd.') + get_demos.set_defaults(action=_get_demos) return parser +def _logging_setup(verbosity=1, logfile=None): + logger = logging.getLogger() + log_level = (2 - verbosity) * 10 + fmt = '%(asctime)s - %(process)d - %(levelname)s - %(module)s - %(message)s' + formatter = logging.Formatter(fmt) + logger.setLevel(log_level) + logger.propagate = False + + if logfile: + file_handler = logging.FileHandler(logfile) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + else: + console_handler = logging.StreamHandler() + console_handler.setLevel(log_level) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + def main(): parser = _get_parser() args = parser.parse_args() + _logging_setup(args.verbose, args.logfile) + if not args.action: parser.print_help() parser.exit() diff --git a/atm/config.py b/atm/config.py index 7a39961..9d79704 100644 --- a/atm/config.py +++ b/atm/config.py @@ -1,19 +1,15 @@ from __future__ import absolute_import, unicode_literals -import logging +import argparse import os import re -import socket -import sys -from argparse import ArgumentError, ArgumentTypeError, RawTextHelpFormatter -from builtins import map, object, str +from builtins import object, str import yaml from atm.constants import ( - BUDGET_TYPES, CUSTOM_CLASS_REGEX, DATA_TEST_PATH, JSON_REGEX, LOG_LEVELS, METHODS, METRICS, - SCORE_TARGETS, SELECTORS, SQL_DIALECTS, TIME_FMT, TUNERS) -from atm.utilities import ensure_directory + BUDGET_TYPES, CUSTOM_CLASS_REGEX, DATA_TEST_PATH, JSON_REGEX, METHODS, METRICS, SCORE_TARGETS, + SELECTORS, SQL_DIALECTS, TIME_FMT, TUNERS) class Config(object): @@ -30,172 +26,141 @@ class Config(object): Subclasses do not need to define __init__ or any other methods. 
""" - # list of all parameters which may be set on this config object - PARAMETERS = [] - # default values for all required parameters - DEFAULTS = {} + _PREFIX = None + _CONFIG = None + + @classmethod + def _add_prefix(cls, name): + if cls._PREFIX: + return '{}_{}'.format(cls._PREFIX, name) + else: + return name + + @classmethod + def _get_arg(cls, args, name): + arg_name = cls._add_prefix(name) + class_value = getattr(cls, name) + required = False + if isinstance(class_value, dict): + required = 'default' not in class_value + default = class_value.get('default') + elif isinstance(class_value, tuple): + required = False + default = class_value[1] + else: + required = False + default = None + + if required and arg_name not in args: + raise KeyError(arg_name) + + return args.get(arg_name, default) + + def __init__(self, args, path=None): + if isinstance(args, argparse.Namespace): + args = vars(args) + + config_arg = self._CONFIG or self._PREFIX + if not path and config_arg: + path = args.get(config_arg + '_config') + + if path: + with open(path, 'r') as f: + args = yaml.load(f) + + for name, value in vars(self.__class__).items(): + if not name.startswith('_') and not callable(value): + setattr(self, name, self._get_arg(args, name)) + + @classmethod + def get_parser(cls): + parser = argparse.ArgumentParser(add_help=False) + + # make sure the text for these arguments is formatted correctly + # this allows newlines in the help strings + parser.formatter_class = argparse.RawTextHelpFormatter + + if cls._PREFIX: + parser.add_argument('--{}-config'.format(cls._PREFIX), + help='path to yaml {} config file'.format(cls._PREFIX)) + + for name, description in vars(cls).items(): + if not name.startswith('_') and not callable(description): + arg_name = '--' + cls._add_prefix(name).replace('_', '-') + + if isinstance(description, dict): + parser.add_argument(arg_name, **description) + + elif isinstance(description, tuple): + description, default = description + parser.add_argument(arg_name, help=description, default=default) + + else: + parser.add_argument(arg_name, help=description) + + return parser + + def to_dict(self): + return { + name: value + for name, value in vars(self).items() + if not name.startswith('_') and not callable(value) + } + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, self.to_dict()) - def __init__(self, **kwargs): - for key in self.PARAMETERS: - value = kwargs.get(key) - # Here, if a keyword argument is set to None, it will be overridden - # by the default value. AFAIK, this is the only way to deal with - # keyword args passed in from argparse that weren't set on the - # command line. That means you shouldn't define any PARAMETERS for - # which None is a meaningful value; if you do, make sure None is - # also the default. 
- if key in self.DEFAULTS and value is None: - value = self.DEFAULTS[key] +class AWSConfig(Config): + """ Stores configuration for AWS S3 connections """ + _PREFIX = 'aws' - setattr(self, key, value) + access_key = 'AWS access key' + secret_key = 'AWS secret key' + s3_bucket = 'AWS S3 bucket to store data' + s3_folder = 'Folder in AWS S3 bucket in which to store data' -class AWSConfig(Config): - """ Stores configuration for AWS S3 and EC2 connections """ - PARAMETERS = [ - # universal config - 'access_key', - 'secret_key', - - # s3 config - 's3_bucket', - 's3_folder', - - # ec2 config - 'ec2_region', - 'ec2_amis', - 'ec2_key_pair', - 'ec2_keyfile', - 'ec2_instance_type', - 'ec2_username', - 'num_instances', - 'num_workers_per_instance' - ] - - DEFAULTS = {} +class DatasetConfig(Config): + """ Stores configuration of a Dataset """ + _CONFIG = 'run' + + train_path = ('Path to raw training data', os.path.join(DATA_TEST_PATH, 'pollution_1.csv')) + test_path = 'Path to raw test data (if applicable)' + data_description = 'Description of dataset' + class_column = ('Name of the class column in the input data', 'class') class SQLConfig(Config): """ Stores configuration for SQL database setup & connection """ - PARAMETERS = [ - 'dialect', - 'database', - 'username', - 'password', - 'host', - 'port', - 'query' - ] - - DEFAULTS = { - 'dialect': 'sqlite', - 'database': 'atm.db', - } - + _PREFIX = 'sql' -class RunConfig(Config): - """ Stores configuration for Dataset and Datarun setup """ - PARAMETERS = [ - # dataset config - 'train_path', - 'test_path', - 'data_description', - 'class_column', - - # datarun config - 'dataset_id', - 'methods', - 'priority', - 'budget_type', - 'budget', - 'deadline', - 'tuner', - 'r_minimum', - 'gridding', - 'selector', - 'k_window', - 'metric', - 'score_target' - ] - - DEFAULTS = { - 'train_path': os.path.join(DATA_TEST_PATH, 'pollution_1.csv'), - 'class_column': 'class', - 'methods': ['logreg', 'dt', 'knn'], - 'priority': 1, - 'budget_type': 'classifier', - 'budget': 100, - 'tuner': 'uniform', - 'selector': 'uniform', - 'r_minimum': 2, - 'k_window': 3, - 'gridding': 0, - 'metric': 'f1', - 'score_target': 'cv', + dialect = { + 'help': 'Dialect of SQL to use', + 'default': 'sqlite', + 'choices': SQL_DIALECTS } + database = ('Name of, or path to, SQL database', 'atm.db') + username = 'Username for SQL database' + password = 'Password for SQL database' + host = 'Hostname for database machine' + port = 'Port used to connect to database' + query = 'Specify extra login details' class LogConfig(Config): - PARAMETERS = [ - 'log_level_stdout', - 'log_level_file', - 'log_dir', - 'model_dir', - 'metric_dir', - 'verbose_metrics', - ] - - DEFAULTS = { - 'log_level_stdout': 'ERROR', - 'log_level_file': 'INFO', - 'log_dir': 'logs', - 'model_dir': 'models', - 'metric_dir': 'metrics', - 'verbose_metrics': False, + model_dir = ('Directory where computed models will be saved', 'models') + metric_dir = ('Directory where model metrics will be saved', 'metrics') + verbose_metrics = { + 'help': ( + 'If set, compute full ROC and PR curves and ' + 'per-label metrics for each classifier' + ), + 'action': 'store_true', + 'default': False } -def initialize_logging(config): - file_level = LOG_LEVELS.get(config.log_level_file.upper(), - logging.CRITICAL) - stdout_level = LOG_LEVELS.get(config.log_level_stdout.upper(), - logging.CRITICAL) - - handlers = [] - if file_level > logging.NOTSET: - fmt = '%(asctime)-15s %(name)s - %(levelname)s %(message)s' - ensure_directory(config.log_dir) - path = 
os.path.join(config.log_dir, socket.gethostname() + '.txt') - handler = logging.FileHandler(path) - handler.setFormatter(logging.Formatter(fmt)) - handler.setLevel(file_level) - handlers.append(handler) - - if stdout_level > logging.NOTSET: - fmt = '%(message)s' - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter(fmt)) - handler.setLevel(stdout_level) - handlers.append(handler) - - if not len(handlers): - handlers.append(logging.NullHandler()) - - for lib in ['atm', 'btb']: - logger = logging.getLogger(lib) - logger.setLevel(min(file_level, stdout_level)) - - for h in logger.handlers: - logger.removeHandler(h) - - for h in handlers: - logger.addHandler(h) - - logger.propagate = False - logger.debug('Logging is active for module %s.' % lib) - - def option_or_path(options, regex=CUSTOM_CLASS_REGEX): def type_check(s): # first, check whether the argument is one of the preconfigured options @@ -209,164 +174,26 @@ def type_check(s): return s # if both of those fail, there's something wrong - raise ArgumentTypeError('%s is not a valid option or path!' % s) + raise argparse.ArgumentTypeError('{} is not a valid option or path!'.format(s)) return type_check -def add_arguments_logging(parser): - """ - Add all argparse arguments needed to parse logging configuration from the - command line. - parser: an argparse.ArgumentParser object - """ - # Config file path - parser.add_argument('--log-config', help='path to yaml logging config file') - - # paths to saved files - parser.add_argument('--model-dir', - help='Directory where computed models will be saved') - parser.add_argument('--metric-dir', - help='Directory where model metrics will be saved') - parser.add_argument('--log-dir', - help='Directory where logs will be saved') - - # hoe much information to log or save - parser.add_argument('--verbose-metrics', action='store_true', - help='If set, compute full ROC and PR curves and ' - 'per-label metrics for each classifier') - - log_levels = list(map(str.lower, list(LOG_LEVELS.keys()))) - parser.add_argument('--log-level-file', choices=log_levels, - help='minimum log level to write to the log file') - # if this is being called from the command line, print more information to - # stdout by default - parser.add_argument('--log-level-stdout', choices=log_levels, - help='minimum log level to write to stdout') - - return parser - - -def add_arguments_aws_s3(parser): - """ - Add all argparse arguments needed to parse AWS S3 configuration from the - command line. This is separate from aws_ec2 because usually only one set of - arguments or the other is needed. - parser: an argparse.ArgumentParser object - """ - # Config file - parser.add_argument('--aws-config', help='path to yaml AWS config file') - - # All of these arguments must start with --aws-, and must correspond to - # keys present in the AWS config example file. - # AWS API access key pair - # try... catch because this might be called after aws_s3 - try: - parser.add_argument('--aws-access-key', help='AWS access key') - parser.add_argument('--aws-secret-key', help='AWS secret key') - except ArgumentError: - pass - - # S3-specific arguments - parser.add_argument('--aws-s3-bucket', help='AWS S3 bucket to store data') - parser.add_argument('--aws-s3-folder', help='Folder in AWS S3 bucket in which to store data') - - return parser - - -def add_arguments_aws_ec2(parser): - """ - Add all argparse arguments needed to parse AWS EC2 configuration from the - command line. 
This is separate from aws_s3 because usually only one set of - arguments or the other is needed. - parser: an argparse.ArgumentParser object - """ - # Config file - parser.add_argument('--aws-config', help='path to yaml AWS config file') - - # All of these arguments must start with --aws-, and must correspond to - # keys present in the AWS config example file. - # AWS API access key pair - # try... catch because this might be called after aws_s3 - try: - parser.add_argument('--aws-access-key', help='AWS access key') - parser.add_argument('--aws-secret-key', help='AWS secret key') - except ArgumentError: - pass - - # AWS EC2 configurations - parser.add_argument('--num-instances', help='Number of EC2 instances to start') - parser.add_argument('--num-workers-per-instance', help='Number of ATM workers per instances') - parser.add_argument('--ec2-region', help='Region to start instances in') - parser.add_argument('--ec2-ami', help='Name of ATM AMI') - parser.add_argument('--ec2-key-pair', help='AWS key pair to use for EC2 instances') - parser.add_argument('--ec2-keyfile', help='Local path to key file (must match ec2-key-pair)') - parser.add_argument('--ec2-instance-type', help='Type of EC2 instance to start') - parser.add_argument('--ec2-username', help='Username to log into EC2 instance') - - return parser - - -def add_arguments_sql(parser): - """ - Add all argparse arguments needed to parse configuration for the ModelHub - SQL database from the command line. - - parser: an argparse.ArgumentParser object - """ - # Config file - parser.add_argument('--sql-config', help='path to yaml SQL config file') - - # All of these arguments must start with --sql-, and must correspond to - # keys present in the SQL config example file. - parser.add_argument('--sql-dialect', choices=SQL_DIALECTS, - help='Dialect of SQL to use') - parser.add_argument('--sql-database', - help='Name of, or path to, SQL database') - parser.add_argument('--sql-username', help='Username for SQL database') - parser.add_argument('--sql-password', help='Password for SQL database') - parser.add_argument('--sql-host', help='Hostname for database machine') - parser.add_argument('--sql-port', help='Port used to connect to database') - parser.add_argument('--sql-query', help='Specify extra login details') - - return parser - - -def add_arguments_datarun(parser): - """ - Add all argparse arguments needed to parse dataset and datarun configuration - from the command line. 
- - parser: an argparse.ArgumentParser object - """ - # make sure the text for these arguments is formatted correctly - # this allows newlines in the help strings - parser.formatter_class = RawTextHelpFormatter - - # Config file - parser.add_argument('--run-config', help='path to yaml datarun config file') +class RunConfig(Config): + """Stores configuration for Dataset and Datarun setup.""" + _CONFIG = 'run' - # Dataset Arguments ##################################################### - # ########################################################################## - parser.add_argument('--dataset-id', type=int, - help="ID of dataset, if it's already in the database") + dataset_id = { + 'help': 'ID of dataset, if it is already in the database', + 'type': int + } - # These are only relevant if dataset_id is not provided - parser.add_argument('--train-path', help='Path to raw training data') - parser.add_argument('--test-path', help='Path to raw test data (if applicable)') - parser.add_argument('--data-description', help='Description of dataset') - parser.add_argument('--class-column', help='Name of the class column in the input data') + run_per_partition = { + 'help': 'if true, generate a new datarun for each hyperpartition', + 'default': False, + 'action': 'store_true', + } - # Datarun Arguments ##################################################### - # ########################################################################## - # Notes: - # - Support vector machines (svm) can take a long time to train. It's not an - # error, it's just part of what happens when the method happens to explore - # a crappy set of parameters on a powerful algo like this. - # - Stochastic gradient descent (sgd) can sometimes fail on certain - # parameter settings as well. Don't worry, they train SUPER fast, and the - # worker.py will simply log the error and continue. - # # Method options: # logreg - logistic regression # svm - support vector machine @@ -381,23 +208,45 @@ def add_arguments_datarun(parser): # pa - passive aggressive # knn - K nearest neighbors # mlp - multi-layer perceptron - parser.add_argument('--methods', nargs='+', - type=option_or_path(METHODS, JSON_REGEX), - help='Method or list of methods to use for ' - 'classification. Each method can either be one of the ' - 'pre-defined method codes listed below or a path to a ' - 'JSON file defining a custom method.' - '\n\nOptions: [%s]' % ', '.join(str(s) for s in METHODS)) - parser.add_argument('--priority', type=int, - help='Priority of the datarun (higher = more important') - parser.add_argument('--budget-type', choices=BUDGET_TYPES, - help='Type of budget to use') - parser.add_argument('--budget', type=int, - help='Value of the budget, either in classifiers or minutes') - parser.add_argument('--deadline', - help='Deadline for datarun completion. If provided, this ' - 'overrides the configured walltime budget.\nFormat: {}'.format( - TIME_FMT.replace('%', '%%'))) + # + # Notes: + # - Support vector machines (svm) can take a long time to train. It's not an + # error, it's just part of what happens when the method happens to explore + # a crappy set of parameters on a powerful algo like this. + # - Stochastic gradient descent (sgd) can sometimes fail on certain + # parameter settings as well. Don't worry, they train SUPER fast, and the + # worker.py will simply log the error and continue. + methods = { + 'help': ( + 'Method or list of methods to use for ' + 'classification. 
Each method can either be one of the ' + 'pre-defined method codes listed below or a path to a ' + 'JSON file defining a custom method.\n\nOptions: [{}]' + ).format(', '.join(str(s) for s in METHODS)), + 'default': ['logreg', 'dt', 'knn'], + 'type': option_or_path(METHODS, JSON_REGEX), + 'nargs': '+' + } + + priority = { + 'help': 'Priority of the datarun (higher = more important', + 'default': 1, + 'type': int + } + budget_type = { + 'help': 'Type of budget to use', + 'default': 'classifier', + 'choices': BUDGET_TYPES, + } + budget = { + 'help': 'Value of the budget, either in classifiers or minutes', + 'default': 100, + 'type': int, + } + deadline = ( + 'Deadline for datarun completion. If provided, this ' + 'overrides the configured walltime budget.\nFormat: {}' + ).format(TIME_FMT.replace('%', '%%')) # Which field to use to judge performance, for the sake of AutoML # options: @@ -414,24 +263,35 @@ def add_arguments_datarun(parser): # # f1 and roc_auc may be appended with _micro or _macro to use with # multiclass problems. - parser.add_argument('--metric', choices=METRICS, - help='Metric by which ATM should evaluate classifiers. ' - 'The metric function specified here will be used to ' - 'compute the "judgment metric" for each classifier.') + metric = { + 'help': ( + 'Metric by which ATM should evaluate classifiers. ' + 'The metric function specified here will be used to ' + 'compute the "judgment metric" for each classifier.' + ), + 'default': 'f1', + 'choices': METRICS, + } # Which data to use for computing judgment score # cv - cross-validated performance on training data # test - performance on test data # mu_sigma - lower confidence bound on cv score - parser.add_argument('--score-target', choices=SCORE_TARGETS, - help='Determines which judgment metric will be used to ' - 'search the hyperparameter space. "cv" will use the mean ' - 'cross-validated performance, "test" will use the ' - 'performance on a test dataset, and "mu_sigma" will use ' - 'the lower confidence bound on the CV performance.') + score_target = { + 'help': ( + 'Determines which judgment metric will be used to ' + 'search the hyperparameter space. "cv" will use the mean ' + 'cross-validated performance, "test" will use the ' + 'performance on a test dataset, and "mu_sigma" will use ' + 'the lower confidence bound on the CV performance.' + ), + 'default': 'cv', + 'choices': SCORE_TARGETS + } # AutoML Arguments ###################################################### # ########################################################################## + # hyperparameter selection strategy # How should ATM sample hyperparameters from a given hyperpartition? # uniform - pick randomly! (baseline) @@ -440,11 +300,15 @@ def add_arguments_datarun(parser): # gp_eivel - Gaussian Process expected improvement, with randomness added # in based on velocity of improvement # path to custom tuner, defined in python - parser.add_argument('--tuner', type=option_or_path(TUNERS), - help='Type of BTB tuner to use. Can either be one of ' - 'the pre-configured tuners listed below or a path to a ' - 'custom tuner in the form "/path/to/tuner.py:ClassName".' - '\n\nOptions: [%s]' % ', '.join(str(s) for s in TUNERS)) + tuner = { + 'help': ( + 'Type of BTB tuner to use. 
Can either be one of the pre-configured ' + 'tuners listed below or a path to a custom tuner in the form ' + '"/path/to/tuner.py:ClassName".\n\nOptions: [{}]' + ).format(', '.join(str(s) for s in TUNERS)), + 'default': 'uniform', + 'type': option_or_path(TUNERS) + } # How should ATM select a particular hyperpartition from the set of all # possible hyperpartitions? @@ -459,11 +323,15 @@ def add_arguments_datarun(parser): # hieralg - hierarchical MAB: choose a classifier first, then choose # a partition # path to custom selector, defined in python - parser.add_argument('--selector', type=option_or_path(SELECTORS), - help='Type of BTB selector to use. Can either be one of ' - 'the pre-configured selectors listed below or a path to a ' - 'custom tuner in the form "/path/to/selector.py:ClassName".' - '\n\nOptions: [%s]' % ', '.join(str(s) for s in SELECTORS)) + selector = { + 'help': ( + 'Type of BTB selector to use. Can either be one of the pre-configured ' + 'selectors listed below or a path to a custom tuner in the form ' + '"/path/to/selector.py:ClassName".\n\nOptions: [{}]' + ).format(', '.join(str(s) for s in SELECTORS)), + 'default': 'uniform', + 'type': option_or_path(SELECTORS) + } # r_minimum is the number of random runs performed in each hyperpartition before # allowing bayesian opt to select parameters. Consult the thesis to @@ -473,14 +341,20 @@ def add_arguments_datarun(parser): # # train using sample criteria # else # # train using uniform (baseline) - parser.add_argument('--r-minimum', type=int, - help='number of random runs to perform before tuning can occur') + r_minimum = { + 'help': 'number of random runs to perform before tuning can occur', + 'default': 2, + 'type': int + } # k is number that xxx-k methods use. It is similar to r_minimum, except it is # called k_window and determines how much "history" ATM considers for certain # partition selection logics. - parser.add_argument('--k-window', type=int, - help='number of previous scores considered by -k selector methods') + k_window = { + 'help': 'number of previous scores considered by -k selector methods', + 'default': 3, + 'type': int + } # gridding determines whether or not sample selection will happen on a grid. # If any positive integer, a grid with `gridding` points on each axis is @@ -488,77 +362,8 @@ def add_arguments_datarun(parser): # space. If 0 (or blank), hyperparameters are sampled from continuous # space, and there is no limit to the number of hyperparameter vectors that # may be tried. - parser.add_argument('--gridding', type=int, - help='gridding factor (0: no gridding)') - - return parser - - -def load_config(sql_path=None, run_path=None, aws_path=None, log_path=None, **kwargs): - """ - Load config objects from yaml files and command line arguments. Command line - args override yaml files where applicable. - - Args: - sql_path: path to .yaml file with SQL configuration - run_path: path to .yaml file with Dataset and Datarun configuration - aws_path: path to .yaml file with AWS configuration - log_path: path to .yaml file with logging configuration - **kwargs: miscellaneous arguments specifying individual configuration - parameters. Any kwargs beginning with sql_ are SQL config - arguments, any beginning with aws_ are AWS config. - - Returns: sql_conf, run_conf, aws_conf, log_conf - """ - sql_args = {} - run_args = {} - aws_args = {} - log_args = {} - - # kwargs are most likely generated by argparse. - # Any unspecified argparse arguments will be None, so ignore those. 
We only - # care about arguments explicitly specified by the user. - kwargs = {k: v for k, v in list(kwargs.items()) if v is not None} - - # check the keyword args for config paths - sql_path = sql_path or kwargs.get('sql_config') - run_path = run_path or kwargs.get('run_config') - aws_path = aws_path or kwargs.get('aws_config') - log_path = log_path or kwargs.get('log_config') - - # load any yaml config files for which paths were provided - if sql_path: - with open(sql_path) as f: - sql_args = yaml.load(f) - - if run_path: - with open(run_path) as f: - run_args = yaml.load(f) - - if aws_path: - with open(aws_path) as f: - aws_args = yaml.load(f) - - if log_path: - with open(log_path) as f: - log_args = yaml.load(f) - - # Use keyword args to override yaml config values - sql_args.update({k.replace('sql_', ''): v for k, v in list(kwargs.items()) - if 'sql_' in k}) - aws_args.update({k.replace('aws_', ''): v for k, v in list(kwargs.items()) - if 'aws_' in k}) - run_args.update({k: v for k, v in list(kwargs.items()) if k in - RunConfig.PARAMETERS}) - log_args.update({k: v for k, v in list(kwargs.items()) if k in - LogConfig.PARAMETERS}) - - # It's ok if there are some extra arguments that get passed in here; only - # kwargs that correspond to real config values will be stored on the config - # objects. - sql_conf = SQLConfig(**sql_args) - aws_conf = AWSConfig(**aws_args) - run_conf = RunConfig(**run_args) - log_conf = LogConfig(**log_args) - - return sql_conf, run_conf, aws_conf, log_conf + gridding = { + 'help': 'gridding factor (0: no gridding)', + 'default': 0, + 'type': int + } diff --git a/atm/models.py b/atm/core.py similarity index 59% rename from atm/models.py rename to atm/core.py index 399919c..b7946ea 100644 --- a/atm/models.py +++ b/atm/core.py @@ -1,45 +1,39 @@ +# -*- coding: utf-8 -*- + +"""Core ATM module. + +This module contains the ATM class, which is the one responsible for +executing and orchestrating the main ATM functionalities. +""" + from __future__ import absolute_import, division, unicode_literals import logging import os import random import time -from builtins import map, object +from builtins import object from datetime import datetime, timedelta from operator import attrgetter -from past.utils import old_div - -from atm.config import initialize_logging, load_config -from atm.constants import PROJECT_ROOT, TIME_FMT, PartitionStatus +from atm.constants import TIME_FMT, PartitionStatus from atm.database import Database from atm.encoder import MetaData from atm.method import Method from atm.utilities import download_data, get_public_ip from atm.worker import ClassifierError, Worker -# load the library-wide logger -logger = logging.getLogger('atm') +LOGGER = logging.getLogger(__name__) class ATM(object): - """ - Thiss class is code API instance that allows you to use ATM in your python code. 
- """ LOOP_WAIT = 1 - def __init__(self, **kwargs): - - if kwargs.get('log_config') is None: - kwargs['log_config'] = os.path.join(PROJECT_ROOT, - 'config/templates/log-script.yaml') - - self.sql_conf, self.run_conf, self.aws_conf, self.log_conf = load_config(**kwargs) - - self.db = Database(**vars(self.sql_conf)) - - initialize_logging(self.log_conf) + def __init__(self, sql_conf, aws_conf, log_conf): + self.db = Database(**sql_conf.to_dict()) + self.aws_conf = aws_conf + self.log_conf = log_conf def work(self, datarun_ids=None, save_files=False, choose_randomly=True, cloud_mode=False, total_time=None, wait=True): @@ -71,13 +65,13 @@ def work(self, datarun_ids=None, save_files=False, choose_randomly=True, dataruns = self.db.get_dataruns(include_ids=datarun_ids, ignore_complete=True) if not dataruns: if wait: - logger.warning('No dataruns found. Sleeping %d seconds and trying again.', - ATM.LOOP_WAIT) + LOGGER.debug('No dataruns found. Sleeping %d seconds and trying again.', + ATM.LOOP_WAIT) time.sleep(ATM.LOOP_WAIT) continue else: - logger.warning('No dataruns found. Exiting.') + LOGGER.info('No dataruns found. Exiting.') break max_priority = max([datarun.priority for datarun in dataruns]) @@ -92,7 +86,7 @@ def work(self, datarun_ids=None, save_files=False, choose_randomly=True, # say we've started working on this datarun, if we haven't already self.db.mark_datarun_running(run.id) - logger.info('Computing on datarun %d' % run.id) + LOGGER.info('Computing on datarun %d' % run.id) # actual work happens here worker = Worker(self.db, run, save_files=save_files, cloud_mode=cloud_mode, aws_config=self.aws_conf, @@ -103,21 +97,21 @@ def work(self, datarun_ids=None, save_files=False, choose_randomly=True, except ClassifierError: # the exception has already been handled; just wait a sec so we # don't go out of control reporting errors - logger.warning('Something went wrong. Sleeping %d seconds.', ATM.LOOP_WAIT) + LOGGER.error('Something went wrong. Sleeping %d seconds.', ATM.LOOP_WAIT) time.sleep(ATM.LOOP_WAIT) elapsed_time = (datetime.now() - start_time).total_seconds() if total_time is not None and elapsed_time >= total_time: - logger.warning('Total run time for worker exceeded; exiting.') + LOGGER.info('Total run time for worker exceeded; exiting.') break - def create_dataset(self): + def create_dataset(self, dataset_conf): """ Create a dataset and add it to the ModelHub database. 
""" # download data to the local filesystem to extract metadata - train_local, test_local = download_data(self.run_conf.train_path, - self.run_conf.test_path, + train_local, test_local = download_data(dataset_conf.train_path, + dataset_conf.test_path, self.aws_conf) # create the name of the dataset from the path to the data @@ -125,22 +119,22 @@ def create_dataset(self): name = name.replace("_train.csv", "").replace(".csv", "") # process the data into the form ATM needs and save it to disk - meta = MetaData(self.run_conf.class_column, train_local, test_local) + meta = MetaData(dataset_conf.class_column, train_local, test_local) # enter dataset into database dataset = self.db.create_dataset(name=name, - description=self.run_conf.data_description, - train_path=self.run_conf.train_path, - test_path=self.run_conf.test_path, - class_column=self.run_conf.class_column, + description=dataset_conf.data_description, + train_path=dataset_conf.train_path, + test_path=dataset_conf.test_path, + class_column=dataset_conf.class_column, n_examples=meta.n_examples, k_classes=meta.k_classes, d_features=meta.d_features, majority=meta.majority, - size_kb=old_div(meta.size, 1000)) + size_kb=meta.size) return dataset - def create_datarun(self, dataset): + def create_datarun(self, dataset, run_conf): """ Given a config, creates a set of dataruns for the config and enters them into the database. Returns the ID of the created datarun. @@ -148,72 +142,67 @@ def create_datarun(self, dataset): dataset: Dataset SQLAlchemy ORM object """ # describe the datarun by its tuner and selector - run_description = '__'.join([self.run_conf.tuner, self.run_conf.selector]) + run_description = '__'.join([run_conf.tuner, run_conf.selector]) # set the deadline, if applicable - deadline = self.run_conf.deadline + deadline = run_conf.deadline if deadline: deadline = datetime.strptime(deadline, TIME_FMT) # this overrides the otherwise configured budget_type # TODO: why not walltime and classifiers budget simultaneously? - self.run_conf.budget_type = 'walltime' - elif self.run_conf.budget_type == 'walltime': - deadline = datetime.now() + timedelta(minutes=self.run_conf.budget) + run_conf.budget_type = 'walltime' + elif run_conf.budget_type == 'walltime': + deadline = datetime.now() + timedelta(minutes=run_conf.budget) - target = self.run_conf.score_target + '_judgment_metric' + target = run_conf.score_target + '_judgment_metric' datarun = self.db.create_datarun(dataset_id=dataset.id, description=run_description, - tuner=self.run_conf.tuner, - selector=self.run_conf.selector, - gridding=self.run_conf.gridding, - priority=self.run_conf.priority, - budget_type=self.run_conf.budget_type, - budget=self.run_conf.budget, + tuner=run_conf.tuner, + selector=run_conf.selector, + gridding=run_conf.gridding, + priority=run_conf.priority, + budget_type=run_conf.budget_type, + budget=run_conf.budget, deadline=deadline, - metric=self.run_conf.metric, + metric=run_conf.metric, score_target=target, - k_window=self.run_conf.k_window, - r_minimum=self.run_conf.r_minimum) + k_window=run_conf.k_window, + r_minimum=run_conf.r_minimum) return datarun - def enter_data(self, run_per_partition=False): + def create_dataruns(self, run_conf): """ Generate a datarun, including a dataset if necessary. Returns: ID of the generated datarun """ - # connect to the database - - # if the user has provided a dataset id, use that. Otherwise, create a new - # dataset based on the arguments we were passed. 
- if self.run_conf.dataset_id is None: - dataset = self.create_dataset() - self.run_conf.dataset_id = dataset.id - else: - dataset = self.db.get_dataset(self.run_conf.dataset_id) + dataset = self.db.get_dataset(run_conf.dataset_id) + if not dataset: + raise ValueError('Invalid Dataset ID: {}'.format(run_conf.dataset_id)) method_parts = {} - for m in self.run_conf.methods: + for m in run_conf.methods: # enumerate all combinations of categorical variables for this method method = Method(m) method_parts[m] = method.get_hyperpartitions() - logger.info('method %s has %d hyperpartitions' % + LOGGER.info('method %s has %d hyperpartitions' % (m, len(method_parts[m]))) # create hyperpartitions and datarun(s) - run_ids = [] - if not run_per_partition: - logger.debug('saving datarun...') - datarun = self.create_datarun(dataset) + dataruns = [] + if not run_conf.run_per_partition: + LOGGER.debug('saving datarun...') + datarun = self.create_datarun(dataset, run_conf) + dataruns.append(datarun) - logger.debug('saving hyperpartions...') + LOGGER.debug('saving hyperpartions...') for method, parts in list(method_parts.items()): for part in parts: # if necessary, create a new datarun for each hyperpartition. # This setting is useful for debugging. - if run_per_partition: - datarun = self.create_datarun(dataset) - run_ids.append(datarun.id) + if run_conf.run_per_partition: + datarun = self.create_datarun(dataset, run_conf) + dataruns.append(datarun) # create a new hyperpartition in the database self.db.create_hyperpartition(datarun_id=datarun.id, @@ -223,19 +212,35 @@ def enter_data(self, run_per_partition=False): categoricals=part.categoricals, status=PartitionStatus.INCOMPLETE) - logger.info('Data entry complete. Summary:') - logger.info('\tDataset ID: %d', dataset.id) - logger.info('\tTraining data: %s', dataset.train_path) - logger.info('\tTest data: %s', (dataset.test_path or 'None')) + LOGGER.info('Dataruns created. Summary:') + LOGGER.info('\tDataset ID: %d', dataset.id) + LOGGER.info('\tTraining data: %s', dataset.train_path) + LOGGER.info('\tTest data: %s', (dataset.test_path or 'None')) - if run_per_partition: - logger.info('\tDatarun IDs: %s', ', '.join(map(str, run_ids))) + datarun = dataruns[0] + if run_conf.run_per_partition: + LOGGER.info('\tDatarun IDs: %s', ', '.join(str(datarun.id) for datarun in dataruns)) else: - logger.info('\tDatarun ID: %d', datarun.id) + LOGGER.info('\tDatarun ID: %d', datarun.id) + + LOGGER.info('\tHyperpartition selection strategy: %s', datarun.selector) + LOGGER.info('\tParameter tuning strategy: %s', datarun.tuner) + LOGGER.info('\tBudget: %d (%s)', datarun.budget, datarun.budget_type) + + return dataruns + + def enter_data(self, dataset_conf, run_conf): + """ + Generate a datarun, including a dataset if necessary. - logger.info('\tHyperpartition selection strategy: %s', datarun.selector) - logger.info('\tParameter tuning strategy: %s', datarun.tuner) - logger.info('\tBudget: %d (%s)', datarun.budget, datarun.budget_type) + Returns: ID of the generated datarun + """ + # if the user has provided a dataset id, use that. Otherwise, create a new + # dataset based on the arguments we were passed. 
+ if run_conf.dataset_id is None: + dataset = self.create_dataset(dataset_conf) + run_conf.dataset_id = dataset.id - return run_ids or datarun.id + dataruns = self.create_dataruns(run_conf) + return dataruns[0] if not run_conf.run_per_partition else dataruns diff --git a/atm/encoder.py b/atm/encoder.py index 2e4b714..b4ec13a 100644 --- a/atm/encoder.py +++ b/atm/encoder.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd -from past.utils import old_div from sklearn.preprocessing import LabelEncoder, OneHotEncoder @@ -26,13 +25,14 @@ def __init__(self, class_column, train_path, test_path=None): for c in data.columns: if data[c].dtype == 'object': total_features += len(np.unique(data[c])) - 1 - majority_percentage = old_div(float(max(counts)), float(sum(counts))) + + majority_percentage = float(max(counts)) / float(sum(counts)) self.n_examples = data.shape[0] self.d_features = total_features self.k_classes = len(np.unique(data[class_column])) self.majority = majority_percentage - self.size = np.array(data).nbytes + self.size = int(np.array(data).nbytes / 1000) class DataEncoder(object): diff --git a/atm/enter_data.py b/atm/enter_data.py deleted file mode 100644 index af5bb10..0000000 --- a/atm/enter_data.py +++ /dev/null @@ -1,161 +0,0 @@ -from __future__ import absolute_import, division, unicode_literals - -import logging -import os -from builtins import map -from datetime import datetime, timedelta - -from past.utils import old_div - -from atm.constants import TIME_FMT, PartitionStatus -from atm.database import Database -from atm.encoder import MetaData -from atm.method import Method -from atm.utilities import download_data - -# load the library-wide logger -logger = logging.getLogger('atm') - - -def create_dataset(db, run_config, aws_config=None): - """ - Create a dataset and add it to the ModelHub database. - - db: initialized Database object - run_config: RunConfig object describing the dataset to create - aws_config: optional. AWS credentials for downloading data from S3. - """ - # download data to the local filesystem to extract metadata - train_local, test_local = download_data(run_config.train_path, - run_config.test_path, - aws_config) - - # create the name of the dataset from the path to the data - name = os.path.basename(train_local) - name = name.replace("_train.csv", "").replace(".csv", "") - - # process the data into the form ATM needs and save it to disk - meta = MetaData(run_config.class_column, train_local, test_local) - - # enter dataset into database - dataset = db.create_dataset(name=name, - description=run_config.data_description, - train_path=run_config.train_path, - test_path=run_config.test_path, - class_column=run_config.class_column, - n_examples=meta.n_examples, - k_classes=meta.k_classes, - d_features=meta.d_features, - majority=meta.majority, - size_kb=old_div(meta.size, 1000)) - return dataset - - -def create_datarun(db, dataset, run_config): - """ - Given a config, creates a set of dataruns for the config and enters them into - the database. Returns the ID of the created datarun. 
- - db: initialized Database object - dataset: Dataset SQLAlchemy ORM object - run_config: RunConfig object describing the datarun to create - """ - # describe the datarun by its tuner and selector - run_description = '__'.join([run_config.tuner, run_config.selector]) - - # set the deadline, if applicable - deadline = run_config.deadline - if deadline: - deadline = datetime.strptime(deadline, TIME_FMT) - # this overrides the otherwise configured budget_type - # TODO: why not walltime and classifiers budget simultaneously? - run_config.budget_type = 'walltime' - elif run_config.budget_type == 'walltime': - deadline = datetime.now() + timedelta(minutes=run_config.budget) - - target = run_config.score_target + '_judgment_metric' - datarun = db.create_datarun(dataset_id=dataset.id, - description=run_description, - tuner=run_config.tuner, - selector=run_config.selector, - gridding=run_config.gridding, - priority=run_config.priority, - budget_type=run_config.budget_type, - budget=run_config.budget, - deadline=deadline, - metric=run_config.metric, - score_target=target, - k_window=run_config.k_window, - r_minimum=run_config.r_minimum) - return datarun - - -def enter_data(sql_config, run_config, aws_config=None, - run_per_partition=False): - """ - Generate a datarun, including a dataset if necessary. - - sql_config: Object with all attributes necessary to initialize a Database. - run_config: all attributes necessary to initialize a Datarun, including - Dataset info if the dataset has not already been created. - aws_config: all attributes necessary to connect to an S3 bucket. - - Returns: ID of the generated datarun - """ - # connect to the database - db = Database(sql_config.dialect, sql_config.database, sql_config.username, - sql_config.password, sql_config.host, sql_config.port, - sql_config.query) - - # if the user has provided a dataset id, use that. Otherwise, create a new - # dataset based on the arguments we were passed. - if run_config.dataset_id is None: - dataset = create_dataset(db, run_config, aws_config=aws_config) - run_config.dataset_id = dataset.id - else: - dataset = db.get_dataset(run_config.dataset_id) - - method_parts = {} - for m in run_config.methods: - # enumerate all combinations of categorical variables for this method - method = Method(m) - method_parts[m] = method.get_hyperpartitions() - logger.info('method %s has %d hyperpartitions' % - (m, len(method_parts[m]))) - - # create hyperpartitions and datarun(s) - run_ids = [] - if not run_per_partition: - logger.debug('saving datarun...') - datarun = create_datarun(db, dataset, run_config) - - logger.debug('saving hyperpartions...') - for method, parts in list(method_parts.items()): - for part in parts: - # if necessary, create a new datarun for each hyperpartition. - # This setting is useful for debugging. - if run_per_partition: - datarun = create_datarun(db, dataset, run_config) - run_ids.append(datarun.id) - - # create a new hyperpartition in the database - db.create_hyperpartition(datarun_id=datarun.id, - method=method, - tunables=part.tunables, - constants=part.constants, - categoricals=part.categoricals, - status=PartitionStatus.INCOMPLETE) - - logger.info('Data entry complete. 
Summary:') - logger.info('\tDataset ID: %d' % dataset.id) - logger.info('\tTraining data: %s' % dataset.train_path) - logger.info('\tTest data: %s' % (dataset.test_path or 'None')) - if run_per_partition: - logger.info('\tDatarun IDs: %s' % ', '.join(map(str, run_ids))) - else: - logger.info('\tDatarun ID: %d' % datarun.id) - logger.info('\tHyperpartition selection strategy: %s' % datarun.selector) - logger.info('\tParameter tuning strategy: %s' % datarun.tuner) - logger.info('\tBudget: %d (%s)' % (datarun.budget, datarun.budget_type)) - - return run_ids or datarun.id diff --git a/atm/metrics.py b/atm/metrics.py index e06e0da..039587e 100644 --- a/atm/metrics.py +++ b/atm/metrics.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd -from past.utils import old_div from sklearn.metrics import ( accuracy_score, average_precision_score, cohen_kappa_score, f1_score, matthews_corrcoef, precision_recall_curve, roc_auc_score, roc_curve) @@ -37,7 +36,7 @@ def rank_n_accuracy(y_true, y_prob_mat, n=0.33): if y_true[i] in rankings[i, :]: correct_sample_count += 1 - return old_div(correct_sample_count, num_samples) + return int(correct_sample_count / num_samples) def get_per_class_matrix(y, classes=None): diff --git a/atm/worker.py b/atm/worker.py index bd44003..4df3c14 100644 --- a/atm/worker.py +++ b/atm/worker.py @@ -55,7 +55,7 @@ def __init__(self, database, datarun, save_files=True, cloud_mode=False, self.aws_config = aws_config self.public_ip = public_ip - log_config = log_config or LogConfig() + log_config = log_config or LogConfig({}) self.model_dir = log_config.model_dir self.metric_dir = log_config.metric_dir self.verbose_metrics = log_config.verbose_metrics @@ -288,6 +288,7 @@ def save_classifier_cloud(self, local_model_path, local_metric_path): local_model_path: path to serialized model in the local file system local_metric_path: path to serialized metrics in the local file system """ + # TODO: This does not work conn = S3Connection(self.aws_config.access_key, self.aws_config.secret_key) bucket = conn.get_bucket(self.aws_config.s3_bucket) diff --git a/docs/Makefile b/docs/Makefile index 330f546..4e63b04 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,225 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -BUILDDIR = ./build +SPHINXBUILD = python -msphinx +SPHINXPROJ = stegdetect +SOURCEDIR = . +BUILDDIR = _build -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source -# the i18n builder cannot share the environment and doctrees with the others -I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source - -.PHONY: help +# Put it first so that "make" without argument is like "make help". 
help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " applehelp to make an Apple Help Book" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " epub3 to make an epub3" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " texinfo to make Texinfo files" - @echo " info to make Texinfo files and run them through makeinfo" - @echo " gettext to make PO message catalogs" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " xml to make Docutils-native XML files" - @echo " pseudoxml to make pseudoxml-XML files for display purposes" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - @echo " coverage to run coverage check of the documentation (if enabled)" - @echo " dummy to check syntax errors of document sources" - -.PHONY: clean -clean: - rm -rf $(BUILDDIR) - -.PHONY: html -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR) - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)" - -.PHONY: dirhtml -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -.PHONY: singlehtml -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -.PHONY: pickle -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -.PHONY: json -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -.PHONY: htmlhelp -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -.PHONY: qthelp -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/atm.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/atm.qhc" - -.PHONY: applehelp -applehelp: - $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp - @echo - @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." - @echo "N.B. You won't be able to view it unless you put it in" \ - "~/Library/Documentation/Help or install it in your application" \ - "bundle." - -.PHONY: devhelp -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." 
- @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/atm" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/atm" - @echo "# devhelp" - -.PHONY: epub -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." - -.PHONY: epub3 -epub3: - $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3 - @echo - @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3." - -.PHONY: latex -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -.PHONY: latexpdf -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -.PHONY: latexpdfja -latexpdfja: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through platex and dvipdfmx..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -.PHONY: text -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -.PHONY: man -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -.PHONY: texinfo -texinfo: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo - @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." - @echo "Run \`make' in that directory to run these through makeinfo" \ - "(use \`make info' here to do that automatically)." - -.PHONY: info -info: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo "Running Texinfo files through makeinfo..." - make -C $(BUILDDIR)/texinfo info - @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." - -.PHONY: gettext -gettext: - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale - @echo - @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." - -.PHONY: changes -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." - -.PHONY: linkcheck -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." - -.PHONY: doctest -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." - -.PHONY: coverage -coverage: - $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage - @echo "Testing of coverage in the sources finished, look at the " \ - "results in $(BUILDDIR)/coverage/python.txt." - -.PHONY: xml -xml: - $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml - @echo - @echo "Build finished. The XML files are in $(BUILDDIR)/xml." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: pseudoxml -pseudoxml: - $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml - @echo - @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 
+.PHONY: help Makefile -.PHONY: dummy -dummy: - $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy - @echo - @echo "Build finished. Dummy builder generates no files." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/add_method.rst b/docs/add_method.rst similarity index 100% rename from docs/source/add_method.rst rename to docs/add_method.rst diff --git a/docs/source/add_to_btb.rst b/docs/add_to_btb.rst similarity index 100% rename from docs/source/add_to_btb.rst rename to docs/add_to_btb.rst diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..7040425 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1 @@ +.. mdinclude:: ../API.md diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..a69888c --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +# +# atm documentation build configuration file, created by +# sphinx-quickstart on Fri Jan 6 13:06:48 2017. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +import sphinx_rtd_theme # For read the docs theme + +import atm + +# -- General configuration --------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = [ + 'm2r', + 'sphinx.ext.autodoc', + 'sphinx.ext.githubpages', + 'sphinx.ext.viewcode', + 'sphinx.ext.napoleon', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +source_suffix = ['.rst', '.md'] + +# The master toctree document. +master_doc = 'index' + + + +# General information about the project. +project = 'ATM' +slug = 'atm' +title = project + ' Documentation' +copyright = '2019, MIT Data to AI Lab' +author = 'Thomas Swearingen, Kalyan Veeramachaneni, Bennett Cyphers' +description = 'ATM: Auto Tune Models' +user = 'HDI-project' + +# The version info for the project you're documenting, acts as replacement +# for |version| and |release|, also used in various other places throughout +# the built documents. +# +# The short X.Y version. +version = atm.__version__ +# The full version, including alpha/beta/rc tags. +release = atm.__version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. 
+# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['.py', '_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Readthedocs additions +html_context = { + 'display_github': True, + 'github_user': user, + 'github_repo': project, + 'github_version': 'master', + 'conf_py_path': '/docs/', +} + +# Theme options are theme-specific and customize the look and feel of a +# theme further. For a list of options available for each theme, see the +# documentation. +html_theme_options = { + 'collapse_navigation': False, + 'display_version': False, +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". + +# The name of an image file (relative to this directory) to use as a favicon of +# the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +html_favicon = 'images/favicon.ico' + +# -- Options for HTMLHelp output --------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = slug + 'doc' + + +# -- Options for LaTeX output ------------------------------------------ + +latex_elements = { +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass +# [howto, manual, or own class]). +latex_documents = [( + master_doc, + slug + '.tex', + title, + author, + 'manual' +)] + + +# -- Options for manual page output ------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [( + master_doc, + slug, + title, + [author], + 1 +)] + + +# -- Options for Texinfo output ---------------------------------------- + +# Grouping the document tree into Texinfo files. 
List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [( + master_doc, + slug, + title, + author, + slug, + description, + 'Miscellaneous' +)] diff --git a/docs/source/contributing.rst b/docs/contributing.rst similarity index 100% rename from docs/source/contributing.rst rename to docs/contributing.rst diff --git a/docs/source/database.rst b/docs/database.rst similarity index 100% rename from docs/source/database.rst rename to docs/database.rst diff --git a/docs/source/index.rst b/docs/index.rst similarity index 98% rename from docs/source/index.rst rename to docs/index.rst index fb99d67..78d2f93 100644 --- a/docs/source/index.rst +++ b/docs/index.rst @@ -15,6 +15,7 @@ Contents: setup quickstart database + api contributing add_method add_to_btb diff --git a/docs/source/introduction.rst b/docs/introduction.rst similarity index 100% rename from docs/source/introduction.rst rename to docs/introduction.rst diff --git a/docs/source/quickstart.rst b/docs/quickstart.rst similarity index 100% rename from docs/source/quickstart.rst rename to docs/quickstart.rst diff --git a/docs/source/setup.rst b/docs/setup.rst similarity index 100% rename from docs/source/setup.rst rename to docs/setup.rst diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index f81232f..0000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,437 +0,0 @@ -# -*- coding: utf-8 -*- -# -# atm documentation build configuration file, created by -# sphinx-quickstart on Fri Jan 6 13:06:48 2017. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. - -import os -import sys -import sphinx_rtd_theme # For read the docs theme -sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('..')) -sys.path.insert(0, os.path.abspath('../..')) - - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'sphinx.ext.githubpages', -] - -autosummary_generate = True - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The encoding of source files. -# -# source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = u'atm' -copyright = u'MIT Data to AI Lab' -author = u'Thomas Swearingen, Kalyan Veeramachaneni, Bennett Cyphers' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. 
-version = u'0.0.1' -# The full version, including alpha/beta/rc tags. -release = u'0.0.1' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -# -# today = '' -# -# Else, today_fmt is used as the format for a strftime call. -# -# today_fmt = '%B %d, %Y' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This patterns also effect to html_static_path and html_extra_path -exclude_patterns = [] - -# The reST default role (used for this markup: `text`) to use for all -# documents. -# -# default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -# -# add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -# -# add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -# -# show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# A list of ignored prefixes for module index sorting. -# modindex_common_prefix = [] - -# If true, keep warnings as "system message" paragraphs in the built documents. -# keep_warnings = False - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'nature' - -#### FOR READ THE DOCS THEME (optional) -html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# -# html_theme_options = {} - -# Add any paths that contain custom themes here, relative to this directory. -# html_theme_path = [] - -# The name for this set of Sphinx documents. -# " v documentation" by default. -# -# html_title = u'atm v0.9' - -# A shorter title for the navigation bar. Default is the same as html_title. -# -# html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -# -# html_logo = None - -# The name of an image file (relative to this directory) to use as a favicon of -# the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -# -# html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# Add any extra paths that contain custom files (such as robots.txt or -# .htaccess) here, relative to this directory. These files are copied -# directly to the root of the documentation. -# -# html_extra_path = [] - -# If not None, a 'Last updated on:' timestamp is inserted at every page -# bottom, using the given strftime format. 
-# The empty string is equivalent to '%b %d, %Y'. -# -# html_last_updated_fmt = None - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -# -# html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -# -# html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -# -# html_additional_pages = {} - -# If false, no module index is generated. -# -# html_domain_indices = True - -# If false, no index is generated. -# -# html_use_index = True - -# If true, the index is split into individual pages for each letter. -# -# html_split_index = False - -# If true, links to the reST sources are added to the pages. -# -# html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -# -# html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -# -# html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -# -# html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -# html_file_suffix = None - -# Language to be used for generating the HTML full-text search index. -# Sphinx supports the following languages: -# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' -# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr', 'zh' -# -# html_search_language = 'en' - -# A dictionary with options for the search language support, empty by default. -# 'ja' uses this config value. -# 'zh' user can custom change `jieba` dictionary path. -# -# html_search_options = {'type': 'default'} - -# The name of a javascript file (relative to the configuration directory) that -# implements a search results scorer. If empty, the default will be used. -# -# html_search_scorer = 'scorer.js' - -# Output file base name for HTML help builder. -htmlhelp_basename = 'atmdoc' - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'atm.tex', u'ATM Documentation', - author, 'manual'), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -# -# latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -# -# latex_use_parts = False - -# If true, show page references after internal links. -# -# latex_show_pagerefs = False - -# If true, show URL addresses after external links. -# -# latex_show_urls = False - -# Documents to append as an appendix to all manuals. -# -# latex_appendices = [] - -# It false, will not define \strong, \code, itleref, \crossref ... but only -# \sphinxstrong, ..., \sphinxtitleref, ... To help avoid clash with user added -# packages. -# -# latex_keep_old_macro_names = True - -# If false, no module index is generated. 
-# -# latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'atm', u'ATM Documentation', - author.split(','), 1) -] - -# If true, show URL addresses after external links. -# -# man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'atm', u'ATM Documentation', - author, 'atm', 'One line description of project.', - 'Miscellaneous'), -] - -# Documents to append as an appendix to all manuals. -# -# texinfo_appendices = [] - -# If false, no module index is generated. -# -# texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -# -# texinfo_show_urls = 'footnote' - -# If true, do not generate a @detailmenu in the "Top" node's menu. -# -# texinfo_no_detailmenu = False - - -# -- Options for Epub output ---------------------------------------------- - -# Bibliographic Dublin Core info. -epub_title = project -epub_author = author -epub_publisher = author -epub_copyright = copyright - -# The basename for the epub file. It defaults to the project name. -# epub_basename = project - -# The HTML theme for the epub output. Since the default themes are not -# optimized for small screen space, using the same theme for HTML and epub -# output is usually not wise. This defaults to 'epub', a theme designed to save -# visual space. -# -# epub_theme = 'epub' - -# The language of the text. It defaults to the language option -# or 'en' if the language is not set. -# -# epub_language = '' - -# The scheme of the identifier. Typical schemes are ISBN or URL. -# epub_scheme = '' - -# The unique identifier of the text. This can be a ISBN number -# or the project homepage. -# -# epub_identifier = '' - -# A unique identification for the text. -# -# epub_uid = '' - -# A tuple containing the cover image and cover page html template filenames. -# -# epub_cover = () - -# A sequence of (type, uri, title) tuples for the guide element of content.opf. -# -# epub_guide = () - -# HTML files that should be inserted before the pages created by sphinx. -# The format is a list of tuples containing the path and title. -# -# epub_pre_files = [] - -# HTML files that should be inserted after the pages created by sphinx. -# The format is a list of tuples containing the path and title. -# -# epub_post_files = [] - -# A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] - -# The depth of the table of contents in toc.ncx. -# -# epub_tocdepth = 3 - -# Allow duplicate toc entries. -# -# epub_tocdup = True - -# Choose between 'default' and 'includehidden'. -# -# epub_tocscope = 'default' - -# Fix unsupported image types using the Pillow. -# -# epub_fix_images = False - -# Scale large images. -# -# epub_max_image_width = 0 - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -# -# epub_show_urls = 'inline' - -# If false, no index is generated. 
-# -# epub_use_index = True diff --git a/docs/source/tutorial.rst b/docs/tutorial.rst similarity index 100% rename from docs/source/tutorial.rst rename to docs/tutorial.rst diff --git a/setup.cfg b/setup.cfg index 6b17b69..f47bc4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,15 +1,15 @@ [bumpversion] -current_version = 0.1.1 +current_version = 0.1.2-dev commit = True tag = True parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\-(?P[a-z]+))? -serialize = +serialize = {major}.{minor}.{patch}-{release} {major}.{minor}.{patch} [bumpversion:part:release] optional_value = release -values = +values = dev release @@ -26,7 +26,7 @@ universal = 1 [flake8] max-line-length = 99 -exclude = docs, .git, __pycache__, .ipynb_checkpoints +exclude = docs, .tox, .git, __pycache__, .ipynb_checkpoints ignore = # Keep empty to prevent default ignores [isort] diff --git a/setup.py b/setup.py index d3ae04d..0ae2840 100644 --- a/setup.py +++ b/setup.py @@ -17,12 +17,19 @@ 'mysqlclient>=1.2', 'numpy>=1.13.1', 'pandas>=0.22.0', + 'psutil>=5.6.1', + 'python-daemon>=2.2.3', 'pyyaml>=3.12', 'requests>=2.18.4', 'scikit-learn>=0.18.2', 'scipy>=0.19.1', 'sklearn-pandas>=1.5.0', 'sqlalchemy>=1.1.14', + 'flask>=1.0.2', + 'flask-restless>=0.17.0', + 'flask-sqlalchemy>=2.3.2', + 'flask-restless-swagger-2>=0.0.3', + 'simplejson>=3.16.0', ] setup_requires = [ @@ -35,6 +42,7 @@ 'pytest-runner>=3.0', 'pytest-xdist>=1.20.1', 'pytest>=3.2.3', + 'google-compute-engine==2.8.12', # required by travis ] development_requires = [ @@ -102,6 +110,6 @@ test_suite='tests', tests_require=tests_require, url='https://github.com/HDI-project/ATM', - version='0.1.1', + version='0.1.2-dev', zip_safe=False, ) diff --git a/tests/test_enter_data.py b/tests/test_core.py similarity index 66% rename from tests/test_enter_data.py rename to tests/test_core.py index 7dfb303..69a83c1 100644 --- a/tests/test_enter_data.py +++ b/tests/test_core.py @@ -3,9 +3,9 @@ import pytest from atm import PROJECT_ROOT -from atm.config import RunConfig, SQLConfig +from atm.config import DatasetConfig, RunConfig, SQLConfig +from atm.core import ATM from atm.database import Database, db_session -from atm.enter_data import create_dataset, enter_data from atm.utilities import get_local_data_path DB_PATH = '/tmp/atm.db' @@ -52,6 +52,8 @@ def test_create_dataset(db): train_url = DATA_URL + 'pollution_1_train.csv' test_url = DATA_URL + 'pollution_1_test.csv' + sql_conf = SQLConfig({'sql_database': DB_PATH}) + train_path_local, _ = get_local_data_path(train_url) if os.path.exists(train_path_local): os.remove(train_path_local) @@ -60,11 +62,16 @@ def test_create_dataset(db): if os.path.exists(test_path_local): os.remove(test_path_local) - run_conf = RunConfig(train_path=train_url, - test_path=test_url, - data_description='test', - class_column='class') - dataset = create_dataset(db, run_conf) + dataset_conf = DatasetConfig({ + 'train_path': train_url, + 'test_path': test_url, + 'data_description': 'test', + 'class_column': 'class' + }) + + atm = ATM(sql_conf, None, None) + + dataset = atm.create_dataset(dataset_conf) dataset = db.get_dataset(dataset.id) assert os.path.exists(train_path_local) @@ -81,46 +88,57 @@ def test_create_dataset(db): def test_enter_data_by_methods(dataset): - sql_conf = SQLConfig(database=DB_PATH) - db = Database(**vars(sql_conf)) - run_conf = RunConfig(dataset_id=dataset.id) + sql_conf = SQLConfig({'sql_database': DB_PATH}) + db = Database(**sql_conf.to_dict()) + run_conf = RunConfig({'dataset_id': dataset.id}) + + atm = ATM(sql_conf, None, None) for method, n_parts in 
METHOD_HYPERPARTS.items(): run_conf.methods = [method] - run_id = enter_data(sql_conf, run_conf) + run_id = atm.enter_data(None, run_conf) - assert db.get_datarun(run_id) with db_session(db): - run = db.get_datarun(run_id) + run = db.get_datarun(run_id.id) assert run.dataset.id == dataset.id assert len(run.hyperpartitions) == n_parts def test_enter_data_all(dataset): - sql_conf = SQLConfig(database=DB_PATH) - db = Database(**vars(sql_conf)) - run_conf = RunConfig(dataset_id=dataset.id, - methods=METHOD_HYPERPARTS.keys()) + sql_conf = SQLConfig({'sql_database': DB_PATH}) + db = Database(**sql_conf.to_dict()) + run_conf = RunConfig({'dataset_id': dataset.id, 'methods': METHOD_HYPERPARTS.keys()}) - run_id = enter_data(sql_conf, run_conf) + atm = ATM(sql_conf, None, None) + + run_id = atm.enter_data(None, run_conf) with db_session(db): - run = db.get_datarun(run_id) + run = db.get_datarun(run_id.id) assert run.dataset.id == dataset.id assert len(run.hyperpartitions) == sum(METHOD_HYPERPARTS.values()) def test_run_per_partition(dataset): - sql_conf = SQLConfig(database=DB_PATH) - db = Database(**vars(sql_conf)) - run_conf = RunConfig(dataset_id=dataset.id, methods=['logreg']) + sql_conf = SQLConfig({'sql_database': DB_PATH}) + db = Database(**sql_conf.to_dict()) + + run_conf = RunConfig( + { + 'dataset_id': dataset.id, + 'methods': ['logreg'], + 'run_per_partition': True + } + ) + + atm = ATM(sql_conf, None, None) - run_ids = enter_data(sql_conf, run_conf, run_per_partition=True) + run_ids = atm.enter_data(None, run_conf) with db_session(db): runs = [] for run_id in run_ids: - run = db.get_datarun(run_id) + run = db.get_datarun(run_id.id) if run is not None: runs.append(run) diff --git a/tests/test_worker.py b/tests/test_worker.py index 2bc5e30..1602a66 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -12,10 +12,10 @@ from atm import PROJECT_ROOT from atm.classifier import Model -from atm.config import LogConfig, RunConfig, SQLConfig +from atm.config import DatasetConfig, LogConfig, RunConfig, SQLConfig from atm.constants import METRICS_BINARY, TIME_FMT +from atm.core import ATM from atm.database import Database, db_session -from atm.enter_data import enter_data from atm.utilities import download_data, load_metrics, load_model from atm.worker import ClassifierError, Worker @@ -106,12 +106,19 @@ def worker(db, datarun): def get_new_worker(**kwargs): + kwargs['dataset_id'] = kwargs.get('dataset_id', None) kwargs['methods'] = kwargs.get('methods', ['logreg', 'dt']) - sql_conf = SQLConfig(database=DB_PATH) - run_conf = RunConfig(**kwargs) - run_id = enter_data(sql_conf, run_conf) - db = Database(**vars(sql_conf)) - datarun = db.get_datarun(run_id) + sql_conf = SQLConfig({'sql_database': DB_PATH}) + run_conf = RunConfig(kwargs) + + dataset_conf = DatasetConfig(kwargs) + + db = Database(**sql_conf.to_dict()) + atm = ATM(sql_conf, None, None) + + run_id = atm.enter_data(dataset_conf, run_conf) + datarun = db.get_datarun(run_id.id) + return Worker(db, datarun) @@ -182,7 +189,7 @@ def test_test_classifier(db, dataset): def test_save_classifier(db, datarun, model, metrics): - log_conf = LogConfig(model_dir=MODEL_DIR, metric_dir=METRIC_DIR) + log_conf = LogConfig({'model_dir': MODEL_DIR, 'metric_dir': METRIC_DIR}) worker = Worker(db, datarun, log_config=log_conf) hp = db.get_hyperpartitions(datarun_id=worker.datarun.id)[0] classifier = worker.db.start_classifier(hyperpartition_id=hp.id, diff --git a/tox.ini b/tox.ini index dca9da6..234d290 100644 --- a/tox.ini +++ b/tox.ini @@ -5,26 +5,28 
@@ envlist = py27, py35, py36, docs, lint [travis] python = 3.6: py36, docs, lint - 3.5: py35 + 3.5: py35, 2.7: py27 [testenv] +passenv = CI TRAVIS TRAVIS_* setenv = PYTHONPATH = {toxinidir} -deps = - .[dev] +extras = tests commands = /usr/bin/env python setup.py test [testenv:lint] skipsdist = true +extras = dev commands = /usr/bin/env make lint [testenv:docs] skipsdist = true +extras = dev commands = /usr/bin/env make docs
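The `atm/config.py` hunks at the top of this section replace `parser.add_argument(...)` calls with plain dictionaries (`tuner`, `selector`, `r_minimum`, `k_window`, `gridding`) whose keys mirror `argparse` keyword arguments. Below is a rough sketch of how such dict specs can still drive an argument parser; the `add_datarun_arguments` helper and the wiring around it are illustrative and not part of this diff.

```python
import argparse

# Hypothetical helper: attach dict-style argument specs (as defined in
# atm/config.py above) to an argparse parser. Only the spec contents come
# from the diff; the wiring itself is an assumption.
def add_datarun_arguments(parser, specs):
    for name, spec in specs.items():
        parser.add_argument('--' + name.replace('_', '-'), **spec)

parser = argparse.ArgumentParser()
add_datarun_arguments(parser, {
    'r_minimum': {
        'help': 'number of random runs to perform before tuning can occur',
        'default': 2,
        'type': int,
    },
    'k_window': {
        'help': 'number of previous scores considered by -k selector methods',
        'default': 3,
        'type': int,
    },
})

args = parser.parse_args(['--r-minimum', '5'])
print(args.r_minimum, args.k_window)  # 5 3
```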
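The `old_div` removals in `atm/encoder.py` and `atm/metrics.py` rely on true-division semantics. The following illustration, assuming non-negative integer counts as in those modules, shows what changes and what stays the same.

```python
from __future__ import division  # true division, also on Python 2

# Majority class share, as computed in MetaData (encoder.py).
counts = [31, 29]
majority = float(max(counts)) / float(sum(counts))  # 0.5166...

# rank_n_accuracy (metrics.py): int(a / b) reproduces old_div's floor
# behaviour for non-negative integers, i.e. the ratio is truncated;
# a bare a / b would instead return the fractional accuracy.
correct, total = 20, 60
print(int(correct / total))  # 0 (same result old_div gave)
print(correct / total)       # 0.333...
```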
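The move from `atm/enter_data.py` and `atm/models.py` into `atm/core.py`, together with the dict-based config objects used in the updated tests, suggests the end-to-end usage sketched below. It is assembled only from the signatures visible in this diff (`ATM(sql_conf, aws_conf, log_conf)`, `enter_data`, `work`); the database path and CSV path are placeholders.

```python
from atm.config import DatasetConfig, RunConfig, SQLConfig
from atm.core import ATM

# Config objects are now built from plain dicts, as in tests/test_core.py.
sql_conf = SQLConfig({'sql_database': '/tmp/atm.db'})      # placeholder path
dataset_conf = DatasetConfig({
    'train_path': 'demos/pollution_1.csv',                 # placeholder path
    'class_column': 'class',
})
run_conf = RunConfig({'dataset_id': None, 'methods': ['logreg', 'dt']})

# aws_conf and log_conf may be None, as the updated tests pass them.
atm = ATM(sql_conf, None, None)

# enter_data() creates the dataset (since dataset_id is None) plus a datarun
# and returns the Datarun object; work() then processes it.
datarun = atm.enter_data(dataset_conf, run_conf)
atm.work(datarun_ids=[datarun.id], total_time=60, wait=False)
```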