Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Drop phase handling #1210

Merged
merged 2 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/server/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._config import SECRET, REVERSE_PROXY_DEPTH
from ._db import engine
from ._exceptions import DatabaseErrorException, EpiDataException
from ._security import current_user, _is_public_route, resolve_auth_token, show_no_api_key_warning, update_key_last_time_used, ERROR_MSG_INVALID_KEY
from ._security import current_user, _is_public_route, resolve_auth_token, update_key_last_time_used, ERROR_MSG_INVALID_KEY


app = Flask("EpiData", static_url_path="")
Expand Down Expand Up @@ -127,11 +127,10 @@ def before_request_execute():
user_id=(user and user.id)
)

if not show_no_api_key_warning():
if not _is_public_route() and api_key and not user:
# if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception:
get_structured_logger("server_api").info("bad api key used", api_key=api_key)
raise Unauthorized(ERROR_MSG_INVALID_KEY)
if not _is_public_route() and api_key and not user:
# if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception:
get_structured_logger("server_api").info("bad api key used", api_key=api_key)
raise Unauthorized(ERROR_MSG_INVALID_KEY)

if request.path.startswith("/lib"):
# files served from 'lib' directory don't need the database, so we can exit this early...
Expand Down
2 changes: 0 additions & 2 deletions src/server/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@
}
NATION_REGION = "nat"

API_KEY_REQUIRED_STARTING_AT = date.fromisoformat(os.environ.get("API_KEY_REQUIRED_STARTING_AT", "2023-06-21"))
TEMPORARY_API_KEY = os.environ.get("TEMPORARY_API_KEY", "TEMP-API-KEY-EXPIRES-2023-06-28")
# password needed for the admin application if not set the admin routes won't be available
ADMIN_PASSWORD = os.environ.get("API_KEY_ADMIN_PASSWORD", "abc")
# secret for the google form to give to the admin/register endpoint
Expand Down
27 changes: 2 additions & 25 deletions src/server/_limiter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from delphi.epidata.server.endpoints.covidcast_utils.dashboard_signals import DashboardSignals
from flask import Response, request, make_response, jsonify
from flask import Response, request
from flask_limiter import Limiter, HEADERS
from redis import Redis
from werkzeug.exceptions import Unauthorized, TooManyRequests
Expand All @@ -8,7 +8,7 @@
from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL, REDIS_HOST, REDIS_PASSWORD
from ._exceptions import ValidationFailedException
from ._params import extract_dates, extract_integers, extract_strings
from ._security import _is_public_route, current_user, require_api_key, show_no_api_key_warning, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES
from ._security import _is_public_route, current_user, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES


def deduct_on_success(response: Response) -> bool:
Expand Down Expand Up @@ -108,38 +108,15 @@ def ratelimit_handler(e):
return TooManyRequests(ERROR_MSG_RATE_LIMIT)


def requests_left():
r = Redis(host=REDIS_HOST, password=REDIS_PASSWORD)
allowed_count, period = RATE_LIMIT.split("/")
try:
remaining_count = int(allowed_count) - int(
r.get(f"LIMITER/{_resolve_tracking_key()}/EpidataLimiter/{allowed_count}/1/{period}")
)
except TypeError:
return 1
return remaining_count


@limiter.request_filter
def _no_rate_limit() -> bool:
if show_no_api_key_warning():
# no rate limit in phase 0
return True
if _is_public_route():
# no rate limit for public routes
return True
if current_user:
# no rate limit if user is registered
return True

if not require_api_key():
# we are in phase 1 or 2
if requests_left() > 0:
# ...and user is below rate limit, we still want to record this query for the rate computation...
return False
# ...otherwise, they have exceeded the limit, but we still want to allow them through
return True

# phase 3 (full api-keys behavior)
multiples = get_multiples_count(request)
if multiples < 0:
Expand Down
90 changes: 11 additions & 79 deletions src/server/_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
from io import StringIO
from typing import Any, Dict, Iterable, List, Optional, Union

from flask import Response, jsonify, stream_with_context, request
from flask import Response, jsonify, stream_with_context
from flask.json import dumps
import orjson

from ._config import MAX_RESULTS, MAX_COMPATIBILITY_RESULTS
# TODO: remove warnings after once we are past the API_KEY_REQUIRED_STARTING_AT date
from ._security import show_hard_api_key_warning, show_soft_api_key_warning, ROLLOUT_WARNING_RATE_LIMIT, ROLLOUT_WARNING_MULTIPLES, _ROLLOUT_WARNING_AD_FRAGMENT, PHASE_1_2_STOPGAP
from ._common import is_compatibility_mode, log_info_with_request
from ._limiter import requests_left, get_multiples_count
from delphi.epidata.common.logger import get_structured_logger


Expand All @@ -25,15 +22,7 @@ def print_non_standard(format: str, data):
message = "no results"
result = -2
else:
warning = ""
if show_hard_api_key_warning():
if requests_left() == 0:
warning = f"{ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
message = warning.strip() or "success"
message = "success"
result = 1
if result == -1 and is_compatibility_mode():
return jsonify(dict(result=result, message=message))
Expand Down Expand Up @@ -126,40 +115,21 @@ class ClassicPrinter(APrinter):
"""

def _begin(self):
if is_compatibility_mode() and not show_hard_api_key_warning():
if is_compatibility_mode():
return "{ "
r = '{ "epidata": ['
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning != "":
return f'{r} "{warning.strip()}",'
return r
return '{ "epidata": ['

def _format_row(self, first: bool, row: Dict):
if first and is_compatibility_mode() and not show_hard_api_key_warning():
if first and is_compatibility_mode():
sep = b'"epidata": ['
else:
sep = b"," if not first else b""
return sep + orjson.dumps(row)

def _end(self):
warning = ""
if show_soft_api_key_warning():
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
message = warning.strip() or "success"
message = "success"
prefix = "], "
if self.count == 0 and is_compatibility_mode() and not show_hard_api_key_warning():
if self.count == 0 and is_compatibility_mode():
# no array to end
prefix = ""

Expand Down Expand Up @@ -193,7 +163,7 @@ def _format_row(self, first: bool, row: Dict):
self._tree[group].append(row)
else:
self._tree[group] = [row]
if first and is_compatibility_mode() and not show_hard_api_key_warning():
if first and is_compatibility_mode():
return b'"epidata": ['
return None

Expand All @@ -204,10 +174,7 @@ def _end(self):
tree = orjson.dumps(self._tree)
self._tree = dict()
r = super(ClassicTreePrinter, self)._end()
r = tree + r
if show_hard_api_key_warning() and (requests_left() == 0 or get_multiples_count(request) < 0):
r = b", " + r
return r
return tree + r


class CSVPrinter(APrinter):
Expand Down Expand Up @@ -239,17 +206,6 @@ def _format_row(self, first: bool, row: Dict):
columns = list(row.keys())
self._writer = DictWriter(self._stream, columns, lineterminator="\n")
self._writer.writeheader()
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
self._writer.writerow({columns[0]: warning})

self._writer.writerow(row)

# remove the stream content to print just one line at a time
Expand All @@ -270,18 +226,7 @@ class JSONPrinter(APrinter):
"""

def _begin(self):
r = b"["
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
r = b'["' + bytes(warning, "utf-8") + b'",'
return r
return b"["

def _format_row(self, first: bool, row: Dict):
sep = b"," if not first else b""
Expand All @@ -299,19 +244,6 @@ class JSONLPrinter(APrinter):
def make_response(self, gen):
return Response(gen, mimetype=" text/plain; charset=utf8")

def _begin(self):
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
return bytes(warning, "utf-8") + b"\n"
return None

def _format_row(self, first: bool, row: Dict):
# each line is a JSON file with a new line to separate them
return orjson.dumps(row, option=orjson.OPT_APPEND_NEWLINE)
Expand All @@ -334,4 +266,4 @@ def create_printer(format: str) -> APrinter:
return CSVPrinter()
if format == "jsonl":
return JSONLPrinter()
return ClassicPrinter()
return ClassicPrinter()
38 changes: 0 additions & 38 deletions src/server/_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,13 @@
from werkzeug.local import LocalProxy

from ._config import (
API_KEY_REQUIRED_STARTING_AT,
REDIS_HOST,
REDIS_PASSWORD,
API_KEY_REGISTRATION_FORM_LINK_LOCAL,
TEMPORARY_API_KEY,
URL_PREFIX,
)
from .admin.models import User, UserRole

API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14)
API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14)

# rollout warning messages
ROLLOUT_WARNING_RATE_LIMIT = "This request exceeded the rate limit on anonymous requests, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT)
ROLLOUT_WARNING_MULTIPLES = "This request exceeded the anonymous limit on selected multiples, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT)
_ROLLOUT_WARNING_AD_FRAGMENT = "To be exempt from this limit, authenticate your requests with a free API key, now available at {}.".format(API_KEY_REGISTRATION_FORM_LINK_LOCAL)

PHASE_1_2_STOPGAP = (
"A temporary public key `{}` is available for use between now and {} to give you time to register or adapt your requests without this message continuing to break your systems."
).format(TEMPORARY_API_KEY, (API_KEY_REQUIRED_STARTING_AT + timedelta(days=7)))


# steady-state error messages
ERROR_MSG_RATE_LIMIT = "Rate limit exceeded for anonymous queries. To remove this limit, register a free API key at {}".format(API_KEY_REGISTRATION_FORM_LINK_LOCAL)
Expand All @@ -54,30 +40,6 @@ def resolve_auth_token() -> Optional[str]:
return None


def show_no_api_key_warning() -> bool:
# aka "phase 0"
n = date.today()
return not current_user and n < API_KEY_SOFT_WARNING


def show_soft_api_key_warning() -> bool:
# aka "phase 1"
n = date.today()
return not current_user and API_KEY_SOFT_WARNING <= n < API_KEY_HARD_WARNING


def show_hard_api_key_warning() -> bool:
# aka "phase 2"
n = date.today()
return not current_user and API_KEY_HARD_WARNING <= n < API_KEY_REQUIRED_STARTING_AT


def require_api_key() -> bool:
# aka "phase 3"
n = date.today()
return API_KEY_REQUIRED_STARTING_AT <= n


def _get_current_user():
if "user" not in g:
api_key = resolve_auth_token()
Expand Down