From 86a1898205297c961d0ba3b6c3b6245338630e74 Mon Sep 17 00:00:00 2001 From: Dmytro Trotsko Date: Mon, 26 Jun 2023 18:21:13 +0300 Subject: [PATCH 1/2] Drop phase handling --- src/server/_common.py | 11 +++-- src/server/_config.py | 2 - src/server/_limiter.py | 27 +------------ src/server/_printer.py | 90 +++++------------------------------------ src/server/_security.py | 38 ----------------- 5 files changed, 18 insertions(+), 150 deletions(-) diff --git a/src/server/_common.py b/src/server/_common.py index 56d4c38ec..8257d12f8 100644 --- a/src/server/_common.py +++ b/src/server/_common.py @@ -11,7 +11,7 @@ from ._config import SECRET, REVERSE_PROXY_DEPTH from ._db import engine from ._exceptions import DatabaseErrorException, EpiDataException -from ._security import current_user, _is_public_route, resolve_auth_token, show_no_api_key_warning, update_key_last_time_used, ERROR_MSG_INVALID_KEY +from ._security import current_user, _is_public_route, resolve_auth_token, update_key_last_time_used, ERROR_MSG_INVALID_KEY app = Flask("EpiData", static_url_path="") @@ -127,11 +127,10 @@ def before_request_execute(): user_id=(user and user.id) ) - if not show_no_api_key_warning(): - if not _is_public_route() and api_key and not user: - # if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception: - get_structured_logger("server_api").info("bad api key used", api_key=api_key) - raise Unauthorized(ERROR_MSG_INVALID_KEY) + if not _is_public_route() and api_key and not user: + # if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception: + get_structured_logger("server_api").info("bad api key used", api_key=api_key) + raise Unauthorized(ERROR_MSG_INVALID_KEY) if request.path.startswith("/lib"): # files served from 'lib' directory don't need the database, so we can exit this early... diff --git a/src/server/_config.py b/src/server/_config.py index 168512a3d..2d4298a62 100644 --- a/src/server/_config.py +++ b/src/server/_config.py @@ -85,8 +85,6 @@ } NATION_REGION = "nat" -API_KEY_REQUIRED_STARTING_AT = date.fromisoformat(os.environ.get("API_KEY_REQUIRED_STARTING_AT", "2023-06-21")) -TEMPORARY_API_KEY = os.environ.get("TEMPORARY_API_KEY", "TEMP-API-KEY-EXPIRES-2023-06-28") # password needed for the admin application if not set the admin routes won't be available ADMIN_PASSWORD = os.environ.get("API_KEY_ADMIN_PASSWORD", "abc") # secret for the google form to give to the admin/register endpoint diff --git a/src/server/_limiter.py b/src/server/_limiter.py index 4bf72e05b..3d84b333c 100644 --- a/src/server/_limiter.py +++ b/src/server/_limiter.py @@ -1,5 +1,5 @@ from delphi.epidata.server.endpoints.covidcast_utils.dashboard_signals import DashboardSignals -from flask import Response, request, make_response, jsonify +from flask import Response, request from flask_limiter import Limiter, HEADERS from redis import Redis from werkzeug.exceptions import Unauthorized, TooManyRequests @@ -8,7 +8,7 @@ from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL, REDIS_HOST, REDIS_PASSWORD from ._exceptions import ValidationFailedException from ._params import extract_dates, extract_integers, extract_strings -from ._security import _is_public_route, current_user, require_api_key, show_no_api_key_warning, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES +from ._security import _is_public_route, current_user, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES def deduct_on_success(response: Response) -> bool: @@ -108,23 +108,8 @@ def ratelimit_handler(e): return TooManyRequests(ERROR_MSG_RATE_LIMIT) -def requests_left(): - r = Redis(host=REDIS_HOST, password=REDIS_PASSWORD) - allowed_count, period = RATE_LIMIT.split("/") - try: - remaining_count = int(allowed_count) - int( - r.get(f"LIMITER/{_resolve_tracking_key()}/EpidataLimiter/{allowed_count}/1/{period}") - ) - except TypeError: - return 1 - return remaining_count - - @limiter.request_filter def _no_rate_limit() -> bool: - if show_no_api_key_warning(): - # no rate limit in phase 0 - return True if _is_public_route(): # no rate limit for public routes return True @@ -132,14 +117,6 @@ def _no_rate_limit() -> bool: # no rate limit if user is registered return True - if not require_api_key(): - # we are in phase 1 or 2 - if requests_left() > 0: - # ...and user is below rate limit, we still want to record this query for the rate computation... - return False - # ...otherwise, they have exceeded the limit, but we still want to allow them through - return True - # phase 3 (full api-keys behavior) multiples = get_multiples_count(request) if multiples < 0: diff --git a/src/server/_printer.py b/src/server/_printer.py index 6e32d7d43..b7dfd1461 100644 --- a/src/server/_printer.py +++ b/src/server/_printer.py @@ -2,15 +2,12 @@ from io import StringIO from typing import Any, Dict, Iterable, List, Optional, Union -from flask import Response, jsonify, stream_with_context, request +from flask import Response, jsonify, stream_with_context from flask.json import dumps import orjson from ._config import MAX_RESULTS, MAX_COMPATIBILITY_RESULTS -# TODO: remove warnings after once we are past the API_KEY_REQUIRED_STARTING_AT date -from ._security import show_hard_api_key_warning, show_soft_api_key_warning, ROLLOUT_WARNING_RATE_LIMIT, ROLLOUT_WARNING_MULTIPLES, _ROLLOUT_WARNING_AD_FRAGMENT, PHASE_1_2_STOPGAP from ._common import is_compatibility_mode, log_info_with_request -from ._limiter import requests_left, get_multiples_count from delphi.epidata.common.logger import get_structured_logger @@ -25,15 +22,7 @@ def print_non_standard(format: str, data): message = "no results" result = -2 else: - warning = "" - if show_hard_api_key_warning(): - if requests_left() == 0: - warning = f"{ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - message = warning.strip() or "success" + message = "success" result = 1 if result == -1 and is_compatibility_mode(): return jsonify(dict(result=result, message=message)) @@ -126,40 +115,21 @@ class ClassicPrinter(APrinter): """ def _begin(self): - if is_compatibility_mode() and not show_hard_api_key_warning(): + if is_compatibility_mode(): return "{ " - r = '{ "epidata": [' - if show_hard_api_key_warning(): - warning = "" - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - if warning != "": - return f'{r} "{warning.strip()}",' - return r + return '{ "epidata": [' def _format_row(self, first: bool, row: Dict): - if first and is_compatibility_mode() and not show_hard_api_key_warning(): + if first and is_compatibility_mode(): sep = b'"epidata": [' else: sep = b"," if not first else b"" return sep + orjson.dumps(row) def _end(self): - warning = "" - if show_soft_api_key_warning(): - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - message = warning.strip() or "success" + message = "success" prefix = "], " - if self.count == 0 and is_compatibility_mode() and not show_hard_api_key_warning(): + if self.count == 0 and is_compatibility_mode(): # no array to end prefix = "" @@ -193,7 +163,7 @@ def _format_row(self, first: bool, row: Dict): self._tree[group].append(row) else: self._tree[group] = [row] - if first and is_compatibility_mode() and not show_hard_api_key_warning(): + if first and is_compatibility_mode(): return b'"epidata": [' return None @@ -204,10 +174,7 @@ def _end(self): tree = orjson.dumps(self._tree) self._tree = dict() r = super(ClassicTreePrinter, self)._end() - r = tree + r - if show_hard_api_key_warning() and (requests_left() == 0 or get_multiples_count(request) < 0): - r = b", " + r - return r + return tree + r class CSVPrinter(APrinter): @@ -239,17 +206,6 @@ def _format_row(self, first: bool, row: Dict): columns = list(row.keys()) self._writer = DictWriter(self._stream, columns, lineterminator="\n") self._writer.writeheader() - if show_hard_api_key_warning(): - warning = "" - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - if warning.strip() != "": - self._writer.writerow({columns[0]: warning}) - self._writer.writerow(row) # remove the stream content to print just one line at a time @@ -270,18 +226,7 @@ class JSONPrinter(APrinter): """ def _begin(self): - r = b"[" - if show_hard_api_key_warning(): - warning = "" - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - if warning.strip() != "": - r = b'["' + bytes(warning, "utf-8") + b'",' - return r + return b"[" def _format_row(self, first: bool, row: Dict): sep = b"," if not first else b"" @@ -299,19 +244,6 @@ class JSONLPrinter(APrinter): def make_response(self, gen): return Response(gen, mimetype=" text/plain; charset=utf8") - def _begin(self): - if show_hard_api_key_warning(): - warning = "" - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - if warning.strip() != "": - return bytes(warning, "utf-8") + b"\n" - return None - def _format_row(self, first: bool, row: Dict): # each line is a JSON file with a new line to separate them return orjson.dumps(row, option=orjson.OPT_APPEND_NEWLINE) @@ -334,4 +266,4 @@ def create_printer(format: str) -> APrinter: return CSVPrinter() if format == "jsonl": return JSONLPrinter() - return ClassicPrinter() + return ClassicPrinter() \ No newline at end of file diff --git a/src/server/_security.py b/src/server/_security.py index 761d088c3..ab3249846 100644 --- a/src/server/_security.py +++ b/src/server/_security.py @@ -9,27 +9,13 @@ from werkzeug.local import LocalProxy from ._config import ( - API_KEY_REQUIRED_STARTING_AT, REDIS_HOST, REDIS_PASSWORD, API_KEY_REGISTRATION_FORM_LINK_LOCAL, - TEMPORARY_API_KEY, URL_PREFIX, ) from .admin.models import User, UserRole -API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14) -API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14) - -# rollout warning messages -ROLLOUT_WARNING_RATE_LIMIT = "This request exceeded the rate limit on anonymous requests, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT) -ROLLOUT_WARNING_MULTIPLES = "This request exceeded the anonymous limit on selected multiples, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT) -_ROLLOUT_WARNING_AD_FRAGMENT = "To be exempt from this limit, authenticate your requests with a free API key, now available at {}.".format(API_KEY_REGISTRATION_FORM_LINK_LOCAL) - -PHASE_1_2_STOPGAP = ( - "A temporary public key `{}` is available for use between now and {} to give you time to register or adapt your requests without this message continuing to break your systems." -).format(TEMPORARY_API_KEY, (API_KEY_REQUIRED_STARTING_AT + timedelta(days=7))) - # steady-state error messages ERROR_MSG_RATE_LIMIT = "Rate limit exceeded for anonymous queries. To remove this limit, register a free API key at {}".format(API_KEY_REGISTRATION_FORM_LINK_LOCAL) @@ -54,30 +40,6 @@ def resolve_auth_token() -> Optional[str]: return None -def show_no_api_key_warning() -> bool: - # aka "phase 0" - n = date.today() - return not current_user and n < API_KEY_SOFT_WARNING - - -def show_soft_api_key_warning() -> bool: - # aka "phase 1" - n = date.today() - return not current_user and API_KEY_SOFT_WARNING <= n < API_KEY_HARD_WARNING - - -def show_hard_api_key_warning() -> bool: - # aka "phase 2" - n = date.today() - return not current_user and API_KEY_HARD_WARNING <= n < API_KEY_REQUIRED_STARTING_AT - - -def require_api_key() -> bool: - # aka "phase 3" - n = date.today() - return API_KEY_REQUIRED_STARTING_AT <= n - - def _get_current_user(): if "user" not in g: api_key = resolve_auth_token() From 6e6db5de45d10c7dab0399540509fcdb0929bc8d Mon Sep 17 00:00:00 2001 From: Dmytro Trotsko Date: Tue, 27 Jun 2023 17:07:38 +0300 Subject: [PATCH 2/2] Remove "phase" comment Co-authored-by: melange396 --- src/server/_limiter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/server/_limiter.py b/src/server/_limiter.py index 3d84b333c..489d6a44b 100644 --- a/src/server/_limiter.py +++ b/src/server/_limiter.py @@ -117,7 +117,6 @@ def _no_rate_limit() -> bool: # no rate limit if user is registered return True - # phase 3 (full api-keys behavior) multiples = get_multiples_count(request) if multiples < 0: # too many multiples