diff --git a/.circleci/config.yml b/.circleci/config.yml
index b9c3e88..07284a4 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -24,7 +24,7 @@ jobs:
     - run: |
         . env/bin/activate &&
         resources/enable_profiling.py &&
-        pytest --tb=short &&
+        pytest -vvvvv --tb=short &&
         resources/profile_queries.py
   deploy:
     docker:
diff --git a/file_catalog/argbuilder.py b/file_catalog/argbuilder.py
new file mode 100644
index 0000000..2b1eedf
--- /dev/null
+++ b/file_catalog/argbuilder.py
@@ -0,0 +1,80 @@
+"""Builder utility functions for arg/kwargs dicts."""
+
+from typing import Any, Dict
+
+from tornado.escape import json_decode
+
+# local imports
+from file_catalog.mongo import AllKeys
+
+
+def build_limit(kwargs: Dict[str, Any], config: Dict[str, Any]) -> None:
+    """Build the `"limit"` argument."""
+    if "limit" in kwargs:
+        kwargs["limit"] = int(kwargs["limit"])
+        if kwargs["limit"] < 1:
+            raise Exception("limit is not positive")
+
+        # check with config
+        if kwargs["limit"] > config["FC_QUERY_FILE_LIST_LIMIT"]:
+            kwargs["limit"] = config["FC_QUERY_FILE_LIST_LIMIT"]
+    else:
+        # if no limit has been defined, set max limit
+        kwargs["limit"] = config["FC_QUERY_FILE_LIST_LIMIT"]
+
+
+def build_start(kwargs: Dict[str, Any]) -> None:
+    """Build the `"start"` argument."""
+    if "start" in kwargs:
+        kwargs["start"] = int(kwargs["start"])
+        if kwargs["start"] < 0:
+            raise Exception("start is negative")
+
+
+def build_files_query(kwargs: Dict[str, Any]) -> None:
+    """Build `"query"` dict with formatted/fully-named arguments.
+
+    Pop corresponding shortcut-keys from `kwargs`.
+    """
+    if "query" in kwargs:
+        # keep whatever was already in here, then add to it
+        if isinstance(kwargs["query"], (str, bytes)):
+            query = json_decode(kwargs.pop("query"))
+        else:
+            query = kwargs.pop("query")
+    else:
+        query = {}
+
+    if "locations.archive" not in query:
+        query["locations.archive"] = None
+
+    # shortcut query params
+    if "logical_name" in kwargs:
+        query["logical_name"] = kwargs.pop("logical_name")
+    if "run_number" in kwargs:
+        query["run.run_number"] = kwargs.pop("run_number")
+    if "dataset" in kwargs:
+        query["iceprod.dataset"] = kwargs.pop("dataset")
+    if "event_id" in kwargs:
+        e = kwargs.pop("event_id")
+        query["run.first_event"] = {"$lte": e}
+        query["run.last_event"] = {"$gte": e}
+    if "processing_level" in kwargs:
+        query["processing_level"] = kwargs.pop("processing_level")
+    if "season" in kwargs:
+        query["offline_processing_metadata.season"] = kwargs.pop("season")
+
+    kwargs["query"] = query
+
+
+def build_keys(kwargs: Dict[str, Any]) -> None:
+    """Build `"keys"` list, potentially using `"all-keys"`.
+
+    Pop `"all-keys"`.
+    """
+    use_all_keys = kwargs.pop("all-keys", None) in ["True", "true", 1]
+
+    if use_all_keys:
+        kwargs["keys"] = AllKeys()
+    elif "keys" in kwargs:
+        kwargs["keys"] = kwargs["keys"].split("|")
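Taken together, these builders normalize a parsed query-string dict in place, so every handler can share the same pre-processing. A minimal usage sketch (the `kwargs` values and the `config` dict are illustrative stand-ins, not real server state):

```python
from file_catalog import argbuilder

# illustrative parsed query-string arguments, as a handler would receive them
kwargs = {"limit": "5", "run_number": 123456, "all-keys": "true"}
config = {"FC_QUERY_FILE_LIST_LIMIT": 10000}  # stand-in for the server config

argbuilder.build_limit(kwargs, config)   # "5" -> 5, capped at the config limit
argbuilder.build_start(kwargs)           # no-op here: no "start" was given
argbuilder.build_files_query(kwargs)     # pops shortcut keys into kwargs["query"]
argbuilder.build_keys(kwargs)            # pops "all-keys" -> kwargs["keys"] = AllKeys()

# kwargs is now ready to be passed to Mongo.find_files(**kwargs):
# {"limit": 5,
#  "query": {"locations.archive": None, "run.run_number": 123456},
#  "keys": <AllKeys instance>}
```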
diff --git a/file_catalog/mongo.py b/file_catalog/mongo.py
index 5faa83b..1d63d07 100644
--- a/file_catalog/mongo.py
+++ b/file_catalog/mongo.py
@@ -1,26 +1,28 @@
+"""File Catalog MongoDB Interface."""
+
+
 from __future__ import absolute_import, division, print_function
 
 import datetime
 import logging
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, cast, Dict, List, Optional, Union
 
-import pymongo
-from bson.objectid import ObjectId
+import pymongo  # type: ignore[import]
 from pymongo import MongoClient
-from pymongo.errors import BulkWriteError
 from tornado.concurrent import run_on_executor
 
-try:
-    from collections.abc import Iterable
-except ImportError:
-    from collections import Iterable
+logger = logging.getLogger("mongo")
 
 
+class AllKeys:  # pylint: disable=R0903
+    """Include all keys in MongoDB find*() methods."""
 
-logger = logging.getLogger('mongo')
 
 class Mongo(object):
-    """A ThreadPoolExecutor-based MongoDB client"""
+    """A ThreadPoolExecutor-based MongoDB client."""
+
+    # fmt:off
     def __init__(self, host=None, port=None, authSource=None, username=None, password=None, uri=None):
         if uri:
@@ -40,12 +42,12 @@ def __init__(self, host=None, port=None, authSource=None, username=None, passwor
         self.client.files.create_index([('locations.site',pymongo.DESCENDING),('locations.path',pymongo.DESCENDING)], background=True)
         self.client.files.create_index('locations.archive', background=True)
         self.client.files.create_index('create_date', background=True)
-
+        # all .i3 files
         self.client.files.create_index('content_status', sparse=True, background=True)
         self.client.files.create_index('processing_level', sparse=True, background=True)
         self.client.files.create_index('data_type', sparse=True, background=True)
-
+        # data_type=real files
         self.client.files.create_index('run_number', sparse=True, background=True)
         self.client.files.create_index('start_datetime', sparse=True, background=True)
@@ -53,7 +55,7 @@ def __init__(self, host=None, port=None, authSource=None, username=None, passwor
         self.client.files.create_index('offline_processing_metadata.first_event', sparse=True, background=True)
         self.client.files.create_index('offline_processing_metadata.last_event', sparse=True, background=True)
         self.client.files.create_index('offline_processing_metadata.season', sparse=True, background=True)
-
+        # data_type=simulation files
         self.client.files.create_index('iceprod.dataset', sparse=True, background=True)
@@ -67,22 +69,41 @@ def __init__(self, host=None, port=None, authSource=None, username=None, passwor
         self.executor = ThreadPoolExecutor(max_workers=10)
         logger.info('done setting up Mongo')
-
-    @run_on_executor
-    def find_files(self, query={}, keys=None, limit=None, start=0):
-        if keys and isinstance(keys,Iterable) and not isinstance(keys,str):
-            projection = {k:True for k in keys}
+    # fmt:on
+
+    @staticmethod
+    def _get_projection(
+        keys: Optional[Union[List[str], AllKeys]] = None,
+        default: Optional[Dict[str, bool]] = None,
+    ) -> Dict[str, bool]:
+        projection = {"_id": False}
+
+        if not keys:
+            if default:  # use default keys if they're available
+                projection.update(default)
+        elif isinstance(keys, AllKeys):
+            pass  # only use "_id" constraint in projection
+        elif isinstance(keys, list):
+            projection.update({k: True for k in keys})
         else:
-            projection = {'uuid':True, 'logical_name':True}
-            projection['_id'] = False
+            raise TypeError(
+                f"`keys` argument ({keys}) is not NoneType, list, or AllKeys"
+            )
 
-        result = self.client.files.find(query, projection)
-        ret = []
+        return projection
 
-        # `limit` and `skip` are ignored by __getitem__:
-        # http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.__getitem__
-        #
-        # Therefore, implement it manually:
+    @staticmethod
+    def _limit_result_list(
+        result: List[Dict[str, Any]], limit: Optional[int] = None, start: int = 0,
+    ) -> List[Dict[str, Any]]:
+        """Get sublist of `result` using `limit` and `start`.
+
+        `limit` and `skip` are ignored by __getitem__:
+        http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.__getitem__
+
+        Therefore, implement it manually.
+        """
+        ret = []
         end = None
 
         if limit is not None:
@@ -90,137 +111,223 @@ def find_files(self, query={}, keys=None, limit=None, start=0):
             end = start + limit
 
         for row in result[start:end]:
             ret.append(row)
+
         return ret
 
     @run_on_executor
-    def count_files(self, query={}, **kwargs):
-        ret = self.client.files.count_documents(query)
+    def find_files(
+        self,
+        query: Optional[Dict[str, Any]] = None,
+        keys: Optional[Union[List[str], AllKeys]] = None,
+        limit: Optional[int] = None,
+        start: int = 0,
+    ) -> List[Dict[str, Any]]:
+        """Find files.
+
+        Optionally, apply keyword arguments. "_id" is always excluded.
+
+        Decorators:
+            run_on_executor
+
+        Keyword Arguments:
+            query -- MongoDB query
+            keys -- fields to include in MongoDB projection
+            limit -- max count of files returned
+            start -- starting index
+
+        Returns:
+            List of MongoDB files
+        """
+        projection = Mongo._get_projection(
+            keys, default={"uuid": True, "logical_name": True}
+        )
+        result = self.client.files.find(query, projection)
+        ret = Mongo._limit_result_list(result, limit, start)
+
        return ret
 
     @run_on_executor
-    def create_file(self, metadata):
+    def count_files(  # pylint: disable=W0613
+        self, query: Optional[Dict[str, Any]] = None, **kwargs: Any,
+    ) -> int:
+        """Get count of files matching query."""
+        if not query:
+            query = {}
+
+        ret = self.client.files.count_documents(query)
+
+        return cast(int, ret)
+
+    @run_on_executor
+    def create_file(self, metadata: Dict[str, Any]) -> str:
+        """Insert file metadata.
+
+        Return uuid.
+        """
         result = self.client.files.insert_one(metadata)
+
         if (not result) or (not result.inserted_id):
-            logger.warn('did not insert file')
-            raise Exception('did not insert new file')
-        return metadata['uuid']
+            msg = "did not insert new file"
+            logger.warning(msg)
+            raise Exception(msg)
+
+        return cast(str, metadata["uuid"])
 
     @run_on_executor
-    def get_file(self, filters):
-        return self.client.files.find_one(filters, {'_id':False})
+    def get_file(self, filters: Dict[str, Any]) -> Dict[str, Any]:
+        """Get file matching filters."""
+        file = self.client.files.find_one(filters, {"_id": False})
+        return cast(Dict[str, Any], file)
 
     @run_on_executor
-    def update_file(self, uuid, metadata):
-        result = self.client.files.update_one({'uuid': uuid},
-                                              {'$set': metadata})
+    def update_file(self, uuid: str, metadata: Dict[str, Any]) -> None:
+        """Update file."""
+        result = self.client.files.update_one({"uuid": uuid}, {"$set": metadata})
         if result.modified_count is None:
-            logger.warn('Cannot determine if document has been modified since `result.modified_count` has the value `None`. `result.matched_count` is %s' % result.matched_count)
+            logger.warning(
+                "Cannot determine if document has been modified since `result.modified_count` has the value `None`. `result.matched_count` is %s",
+                result.matched_count,
+            )
         elif result.modified_count != 1:
-            logger.warn('updated %s files with id %r',
-                        result.modified_count, uuid)
-            raise Exception('did not update')
+            msg = f"updated {result.modified_count} files with id {uuid}"
+            logger.warning(msg)
+            raise Exception(msg)
 
     @run_on_executor
-    def replace_file(self, metadata):
-        uuid = metadata['uuid']
+    def replace_file(self, metadata: Dict[str, Any]) -> None:
+        """Replace file.
+
+        Metadata must include 'uuid'.
+        """
+        uuid = metadata["uuid"]
 
-        result = self.client.files.replace_one({'uuid': uuid},
-                                               metadata)
+        result = self.client.files.replace_one({"uuid": uuid}, metadata)
         if result.modified_count is None:
-            logger.warn('Cannot determine if document has been modified since `result.modified_count` has the value `None`. `result.matched_count` is %s' % result.matched_count)
+            logger.warning(
+                "Cannot determine if document has been modified since `result.modified_count` has the value `None`. `result.matched_count` is %s",
+                result.matched_count,
+            )
         elif result.modified_count != 1:
-            logger.warn('updated %s files with id %r',
-                        result.modified_count, uuid)
-            raise Exception('did not update')
+            msg = f"updated {result.modified_count} files with id {uuid}"
+            logger.warning(msg)
+            raise Exception(msg)
 
     @run_on_executor
-    def delete_file(self, filters):
+    def delete_file(self, filters: Dict[str, Any]) -> None:
+        """Delete file matching filters."""
        result = self.client.files.delete_one(filters)
         if result.deleted_count != 1:
-            logger.warn('deleted %d files with filter %r',
-                        result.deleted_count, filter)
-            raise Exception('did not delete')
+            msg = f"deleted {result.deleted_count} files with filter {filters}"
+            logger.warning(msg)
+            raise Exception(msg)
 
     @run_on_executor
-    def find_collections(self, keys=None, limit=None, start=0):
-        if keys and isinstance(keys,Iterable) and not isinstance(keys,str):
-            projection = {k:True for k in keys}
-        else:
-            projection = {}  # show all fields
-        projection['_id'] = False
-
+    def find_collections(
+        self,
+        keys: Optional[Union[List[str], AllKeys]] = None,
+        limit: Optional[int] = None,
+        start: int = 0,
+    ) -> List[Dict[str, Any]]:
+        """Find all collections.
+
+        Optionally, apply keyword arguments. "_id" is always excluded.
+
+        Decorators:
+            run_on_executor
+
+        Keyword Arguments:
+            keys -- fields to include in MongoDB projection
+            limit -- max count of collections returned
+            start -- starting index
+
+        Returns:
+            List of MongoDB collections
+        """
+        projection = Mongo._get_projection(keys)  # show all fields by default
         result = self.client.collections.find({}, projection)
-        ret = []
+        ret = Mongo._limit_result_list(result, limit, start)
 
-        # `limit` and `skip` are ignored by __getitem__:
-        # http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.__getitem__
-        #
-        # Therefore, implement it manually:
-        end = None
-
-        if limit is not None:
-            end = start + limit
-
-        for row in result[start:end]:
-            ret.append(row)
         return ret
 
     @run_on_executor
-    def create_collection(self, metadata):
+    def create_collection(self, metadata: Dict[str, Any]) -> str:
+        """Create collection, insert metadata.
+
+        Return uuid.
+        """
         result = self.client.collections.insert_one(metadata)
+
         if (not result) or (not result.inserted_id):
-            logger.warn('did not insert collection')
-            raise Exception('did not insert new collection')
-        return metadata['uuid']
+            msg = "did not insert new collection"
+            logger.warning(msg)
+            raise Exception(msg)
 
-    @run_on_executor
-    def get_collection(self, filters):
-        return self.client.collections.find_one(filters, {'_id':False})
+        return cast(str, metadata["uuid"])
 
     @run_on_executor
-    def find_snapshots(self, query={}, keys=None, limit=None, start=0):
-        if keys and isinstance(keys,Iterable) and not isinstance(keys,str):
-            projection = {k:True for k in keys}
-        else:
-            projection = {}  # show all fields
-        projection['_id'] = False
+    def get_collection(self, filters: Dict[str, Any]) -> Dict[str, Any]:
+        """Get collection matching filters."""
+        collection = self.client.collections.find_one(filters, {"_id": False})
+        return cast(Dict[str, Any], collection)
 
+    @run_on_executor
+    def find_snapshots(
+        self,
+        query: Optional[Dict[str, Any]] = None,
+        keys: Optional[Union[List[str], AllKeys]] = None,
+        limit: Optional[int] = None,
+        start: int = 0,
+    ) -> List[Dict[str, Any]]:
+        """Find snapshots.
+
+        Optionally, apply keyword arguments. "_id" is always excluded.
+
+        Decorators:
+            run_on_executor
+
+        Keyword Arguments:
+            query -- MongoDB query
+            keys -- fields to include in MongoDB projection
+            limit -- max count of snapshots returned
+            start -- starting index
+
+        Returns:
+            List of MongoDB snapshots
+        """
+        projection = Mongo._get_projection(keys)  # show all fields by default
         result = self.client.snapshots.find(query, projection)
-        ret = []
-
-        # `limit` and `skip` are ignored by __getitem__:
-        # http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.__getitem__
-        #
-        # Therefore, implement it manually:
-        end = None
-
-        if limit is not None:
-            end = start + limit
+        ret = Mongo._limit_result_list(result, limit, start)
 
-        for row in result[start:end]:
-            ret.append(row)
         return ret
 
     @run_on_executor
-    def create_snapshot(self, metadata):
+    def create_snapshot(self, metadata: Dict[str, Any]) -> str:
+        """Insert metadata into 'snapshots' collection."""
         result = self.client.snapshots.insert_one(metadata)
+
         if (not result) or (not result.inserted_id):
-            logger.warn('did not insert snapshot')
-            raise Exception('did not insert new snapshot')
-        return metadata['uuid']
+            msg = "did not insert new snapshot"
+            logger.warning(msg)
+            raise Exception(msg)
+
+        return cast(str, metadata["uuid"])
 
     @run_on_executor
-    def get_snapshot(self, filters):
-        return self.client.snapshots.find_one(filters, {'_id':False})
+    def get_snapshot(self, filters: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+        """Find snapshot, optionally filtered."""
+        snapshot = self.client.snapshots.find_one(filters, {"_id": False})
+        return cast(Dict[str, Any], snapshot)
 
     @run_on_executor
-    def append_distinct_elements_to_file(self, uuid, metadata):
+    def append_distinct_elements_to_file(
+        self, uuid: str, metadata: Dict[str, Any]
+    ) -> None:
         """Append distinct elements to arrays within a file document."""
         # build the query to update the file document
-        update_query = {"$addToSet": {}}
+        update_query: Dict[str, Any] = {"$addToSet": {}}
         for key in metadata:
             if isinstance(metadata[key], list):
                 update_query["$addToSet"][key] = {"$each": metadata[key]}
@@ -229,12 +336,15 @@ def append_distinct_elements_to_file(self, uuid, metadata):
 
         # update the file document
         update_query["$set"] = {"meta_modify_date": str(datetime.datetime.utcnow())}
-        result = self.client.files.update_one({'uuid': uuid}, update_query)
+        result = self.client.files.update_one({"uuid": uuid}, update_query)
 
         # log and/or throw if the update results are surprising
         if result.modified_count is None:
-            logger.warn('Cannot determine if document has been modified since `result.modified_count` has the value `None`. `result.matched_count` is %s' % result.matched_count)
+            logger.warning(
+                "Cannot determine if document has been modified since `result.modified_count` has the value `None`. `result.matched_count` is %s",
+                result.matched_count,
+            )
         elif result.modified_count != 1:
-            logger.warn('updated %s files with id %r',
-                        result.modified_count, uuid)
-            raise Exception('did not update')
+            msg = f"updated {result.modified_count} files with id {uuid}"
+            logger.warning(msg)
+            raise Exception(msg)
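With this, the projection logic that used to be duplicated across the `find*()` methods lives in one place. For reference, a sketch of how `_get_projection` resolves the three accepted `keys` shapes (expected return values shown in comments):

```python
from file_catalog.mongo import AllKeys, Mongo

# no keys: fall back to the caller-supplied default fields, if any
Mongo._get_projection(None, default={"uuid": True, "logical_name": True})
# -> {"_id": False, "uuid": True, "logical_name": True}

# AllKeys sentinel: only the "_id" exclusion, so MongoDB returns every field
Mongo._get_projection(AllKeys())
# -> {"_id": False}

# explicit list: include exactly the named fields
Mongo._get_projection(["checksum", "file_size"])
# -> {"_id": False, "checksum": True, "file_size": True}

# anything else (e.g. a bare string) raises TypeError
```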
diff --git a/file_catalog/server.py b/file_catalog/server.py
index 00936cd..2c7ea59 100644
--- a/file_catalog/server.py
+++ b/file_catalog/server.py
@@ -1,3 +1,8 @@
+"""File Catalog REST Server Interface."""
+
+# fmt: off
+# isort:skip_file
+
 from __future__ import absolute_import, division, print_function
 
 import copy
@@ -27,15 +32,17 @@
 from tornado.httpclient import HTTPError
 from rest_tools.server import Auth
 
+# local imports
 import file_catalog
 from file_catalog.mongo import Mongo
-from file_catalog import urlargparse
+from file_catalog import urlargparse, argbuilder
 from file_catalog.validation import Validation
 
 logger = logging.getLogger('server')
 
+
 def get_pkgdata_filename(package, resource):
-    """Get a filename for a resource bundled within the package"""
+    """Get a filename for a resource bundled within the package."""
     loader = get_loader(package)
     if loader is None or not hasattr(loader, 'get_data'):
         return None
@@ -50,8 +57,9 @@ def get_pkgdata_filename(package, resource):
         parts.insert(0, os.path.dirname(mod.__file__))
     return os.path.join(*parts)
 
+
 def tornado_logger(handler):
-    """Log levels based on status code"""
+    """Log levels based on status code."""
     if handler.get_status() < 400:
         log_method = logger.debug
     elif handler.get_status() < 500:
@@ -62,9 +70,10 @@ def tornado_logger(handler):
     log_method("%d %s %.2fms", handler.get_status(),
                handler._request_summary(), request_time)
 
+
 def sort_dict(d):
-    """
-    Creates an OrderedDict by taking the `dict` named `d` and orders its keys.
+    """Create an OrderedDict by taking the `dict` named `d` and sorting its keys.
+
     If a key contains a `dict` it will call this function recursively.
     """
@@ -77,11 +86,13 @@ def sort_dict(d):
 
     return od
 
+
 def set_last_modification_date(d):
     d['meta_modify_date'] = str(datetime.datetime.utcnow())
 
+
 class Server(object):
-    """A file_catalog server instance"""
+    """A file_catalog server instance."""
 
     def __init__(self, config, port=8888, debug=False,
                  db_host='localhost', db_port=27017,
@@ -150,8 +161,9 @@ def __init__(self, config, port=8888, debug=False,
     def run(self):
         tornado.ioloop.IOLoop.current().start()
 
+
 class MainHandler(tornado.web.RequestHandler):
-    """Main HTML handler"""
+    """Main HTML handler."""
     def initialize(self, base_url='/', debug=False, config=None):
         self.base_url = base_url
         self.debug = debug
@@ -204,8 +216,9 @@ def write_error(self,status_code=500,**kwargs):
             self.write('<br />'.join(kwargs['message'].split('\n')))
         self.finish()
 
+
 def catch_error(method):
-    """Decorator to catch and handle errors on api handlers"""
+    """Decorator to catch and handle errors on api handlers."""
     @wraps(method)
     def wrapper(self, *args, **kwargs):
         try:
@@ -218,8 +231,9 @@ def wrapper(self, *args, **kwargs):
             self.send_error(**kwargs)
     return wrapper
 
+
 class LoginHandler(MainHandler):
-    """Login HTML handler"""
+    """Login HTML handler."""
     @catch_error
     def get(self):
         if not self.get_argument('access', False):
@@ -240,7 +254,7 @@ def get(self):
 
 
 class AccountHandler(MainHandler):
-    """Account HTML handler"""
+    """Account HTML handler."""
     @catch_error
     def get(self):
         if not self.get_argument('access', False):
@@ -255,8 +269,9 @@ def get(self):
         refresh = self.get_argument('refresh')
         self.render('account.html', authkey=refresh, tempkey=access)
 
+
 def validate_auth(method):
-    """Decorator to check auth key on api handlers"""
+    """Decorator to check auth key on api handlers."""
     @wraps(method)
     def wrapper(self, *args, **kwargs):
         if not self.auth:  # skip auth if not present
@@ -278,8 +293,9 @@ def wrapper(self, *args, **kwargs):
         return method(self, *args, **kwargs)
     return wrapper
 
+
 class APIHandler(tornado.web.RequestHandler):
-    """Base class for API handlers"""
+    """Base class for API handlers."""
     def initialize(self, config, db=None, base_url='/', debug=False, rate_limit=10):
         self.db = db
         self.base_url = base_url
@@ -336,6 +352,7 @@ def write_error(self,status_code=500,**kwargs):
             self.write(kwargs)
         self.finish()
 
+
 class HATEOASHandler(APIHandler):
     def initialize(self, **kwargs):
         super(HATEOASHandler, self).initialize(**kwargs)
@@ -353,41 +370,6 @@ def get(self):
         self.write(self.data)
 
 
-def build_files_query(kwargs: dict) -> dict:
-    """Return dict with formatted/fully-named arguments for querying files.
-
-    Pop corresponding keys from `kwargs`.
-    """
-    if 'query' in kwargs:
-        # keep whatever was already in here, then add to it
-        if isinstance(kwargs['query'], (str, bytes)):
-            query = json_decode(kwargs.pop('query'))
-        else:
-            query = kwargs.pop('query')
-    else:
-        query = {}
-
-    if 'locations.archive' not in query:
-        query['locations.archive'] = None
-
-    # shortcut query params
-    if 'logical_name' in kwargs:
-        query['logical_name'] = kwargs.pop('logical_name')
-    if 'run_number' in kwargs:
-        query['run.run_number'] = kwargs.pop('run_number')
-    if 'dataset' in kwargs:
-        query['iceprod.dataset'] = kwargs.pop('dataset')
-    if 'event_id' in kwargs:
-        e = kwargs.pop('event_id')
-        query['run.first_event'] = {'$lte': e}
-        query['run.last_event'] = {'$gte': e}
-    if 'processing_level' in kwargs:
-        query['processing_level'] = kwargs.pop('processing_level')
-    if 'season' in kwargs:
-        query['offline_processing_metadata.season'] = kwargs.pop('season')
-
-    return query
-
 class FilesHandler(APIHandler):
     def initialize(self, **kwargs):
         super(FilesHandler, self).initialize(**kwargs)
@@ -400,32 +382,17 @@ def initialize(self, **kwargs):
     def get(self):
         try:
             kwargs = urlargparse.parse(self.request.query)
-            if 'limit' in kwargs:
-                kwargs['limit'] = int(kwargs['limit'])
-                if kwargs['limit'] < 1:
-                    raise Exception('limit is not positive')
-
-                # check with config
-                if kwargs['limit'] > self.config['FC_QUERY_FILE_LIST_LIMIT']:
-                    kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-            else:
-                # if no limit has been defined, set max limit
-                kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-
-            if 'start' in kwargs:
-                kwargs['start'] = int(kwargs['start'])
-                if kwargs['start'] < 0:
-                    raise Exception('start is negative')
-
-            kwargs['query'] = build_files_query(kwargs)
-
-            if 'keys' in kwargs:
-                kwargs['keys'] = kwargs['keys'].split('|')
+            argbuilder.build_limit(kwargs, self.config)
+            argbuilder.build_start(kwargs)
+            argbuilder.build_files_query(kwargs)
+            argbuilder.build_keys(kwargs)
         except:
             logging.warn('query parameter error', exc_info=True)
             self.send_error(400, message='invalid query parameters')
             return
+
         files = yield self.db.find_files(**kwargs)
+
         self.write({
             '_links':{
                 'self': {'href': self.files_url},
@@ -502,6 +469,7 @@ def post(self):
             'file': os.path.join(self.files_url, ret),
         })
 
+
 class FilesCountHandler(APIHandler):
     def initialize(self, **kwargs):
         super(FilesCountHandler, self).initialize(**kwargs)
@@ -514,12 +482,14 @@ def initialize(self, **kwargs):
     def get(self):
         try:
             kwargs = urlargparse.parse(self.request.query)
-            build_files_query(kwargs)
+            argbuilder.build_files_query(kwargs)
         except:
             logging.warn('query parameter error', exc_info=True)
             self.send_error(400, message='invalid query parameters')
             return
+
         files = yield self.db.count_files(**kwargs)
+
         self.write({
             '_links':{
                 'self': {'href': self.files_url},
@@ -528,6 +498,7 @@ def get(self):
             'files': files,
         })
 
+
 class SingleFileHandler(APIHandler):
     def initialize(self, **kwargs):
         super(SingleFileHandler, self).initialize(**kwargs)
@@ -695,7 +666,7 @@ class SingleFileLocationsHandler(APIHandler):
     """Initialize a handler for adding new locations to an existing record."""
 
     def initialize(self, **kwargs):
-        """Initialize a handler for adding new locations to an existing record."""
+        """Initialize a handler for adding new locations to existing record."""
         super(SingleFileLocationsHandler, self).initialize(**kwargs)
         self.files_url = os.path.join(self.base_url, 'files')
 
@@ -774,6 +745,7 @@ def initialize(self, **kwargs):
         self.collections_url = os.path.join(self.base_url,'collections')
         self.snapshots_url = os.path.join(self.base_url,'snapshots')
 
+
 class CollectionsHandler(CollectionBaseHandler):
     @validate_auth
     @catch_error
@@ -781,30 +753,16 @@ class CollectionsHandler(CollectionBaseHandler):
     def get(self):
         try:
             kwargs = urlargparse.parse(self.request.query)
-            if 'limit' in kwargs:
-                kwargs['limit'] = int(kwargs['limit'])
-                if kwargs['limit'] < 1:
-                    raise Exception('limit is not positive')
-
-                # check with config
-                if kwargs['limit'] > self.config['FC_QUERY_FILE_LIST_LIMIT']:
-                    kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-            else:
-                # if no limit has been defined, set max limit
-                kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-
-            if 'start' in kwargs:
-                kwargs['start'] = int(kwargs['start'])
-                if kwargs['start'] < 0:
-                    raise Exception('start is negative')
-
-            if 'keys' in kwargs:
-                kwargs['keys'] = kwargs['keys'].split('|')
+            argbuilder.build_limit(kwargs, self.config)
+            argbuilder.build_start(kwargs)
+            argbuilder.build_keys(kwargs)
         except:
             logging.warn('query parameter error', exc_info=True)
             self.send_error(400, message='invalid query parameters')
             return
+
         collections = yield self.db.find_collections(**kwargs)
+
         self.write({
             '_links':{
                 'self': {'href': self.collections_url},
@@ -819,14 +777,13 @@ def get(self):
     def post(self):
         metadata = json_decode(self.request.body)
 
-        query = {}
         try:
-            query = build_files_query(metadata)
+            argbuilder.build_files_query(metadata)
+            metadata['query'] = json_encode(metadata['query'])
         except:
             logging.warn('query parameter error', exc_info=True)
             self.send_error(400, message='invalid query parameters')
             return
-        metadata['query'] = json_encode(query)
 
         if 'collection_name' not in metadata:
             self.send_error(400, message='missing collection_name')
@@ -860,6 +817,7 @@ def post(self):
             'collection': os.path.join(self.collections_url, ret),
         })
 
+
 class SingleCollectionHandler(CollectionBaseHandler):
     @validate_auth
     @catch_error
@@ -879,6 +837,7 @@ def get(self, uid):
         else:
             self.send_error(404, message='collection not found')
 
+
 class SingleCollectionFilesHandler(CollectionBaseHandler):
     @validate_auth
     @catch_error
@@ -891,32 +850,17 @@ def get(self, uid):
         if ret:
             try:
                 kwargs = urlargparse.parse(self.request.query)
-                if 'limit' in kwargs:
-                    kwargs['limit'] = int(kwargs['limit'])
-                    if kwargs['limit'] < 1:
-                        raise Exception('limit is not positive')
-
-                    # check with config
-                    if kwargs['limit'] > self.config['FC_QUERY_FILE_LIST_LIMIT']:
-                        kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-                else:
-                    # if no limit has been defined, set max limit
-                    kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-
-                if 'start' in kwargs:
-                    kwargs['start'] = int(kwargs['start'])
-                    if kwargs['start'] < 0:
-                        raise Exception('start is negative')
-
+                argbuilder.build_limit(kwargs, self.config)
+                argbuilder.build_start(kwargs)
                 kwargs['query'] = json_decode(ret['query'])
-
-                if 'keys' in kwargs:
-                    kwargs['keys'] = kwargs['keys'].split('|')
+                argbuilder.build_keys(kwargs)
             except:
                 logging.warn('query parameter error', exc_info=True)
                 self.send_error(400, message='invalid query parameters')
                 return
+
             files = yield self.db.find_files(**kwargs)
+
             self.write({
                 '_links':{
                     'self': {'href': os.path.join(self.collections_url,uid,'files')},
@@ -927,6 +871,7 @@ def get(self, uid):
         else:
             self.send_error(404, message='collection not found')
 
+
 class SingleCollectionSnapshotsHandler(CollectionBaseHandler):
     @validate_auth
     @catch_error
@@ -941,31 +886,17 @@ def get(self, uid):
 
             try:
                 kwargs = urlargparse.parse(self.request.query)
-                if 'limit' in kwargs:
-                    kwargs['limit'] = int(kwargs['limit'])
-                    if kwargs['limit'] < 1:
-                        raise Exception('limit is not positive')
-
-                    # check with config
-                    if kwargs['limit'] > self.config['FC_QUERY_FILE_LIST_LIMIT']:
-                        kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-                else:
-                    # if no limit has been defined, set max limit
-                    kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-
-                if 'start' in kwargs:
-                    kwargs['start'] = int(kwargs['start'])
-                    if kwargs['start'] < 0:
-                        raise Exception('start is negative')
-
-                if 'keys' in kwargs:
-                    kwargs['keys'] = kwargs['keys'].split('|')
+                argbuilder.build_limit(kwargs, self.config)
+                argbuilder.build_start(kwargs)
+                argbuilder.build_keys(kwargs)
+
+                kwargs['query'] = {'collection_id': ret['uuid']}
             except:
                 logging.warn('query parameter error', exc_info=True)
                 self.send_error(400, message='invalid query parameters')
                 return
-            kwargs['query'] = {'collection_id': ret['uuid']}
+
             snapshots = yield self.db.find_snapshots(**kwargs)
+
             self.write({
                 '_links':{
                     'self': {'href': os.path.join(self.collections_url,uid,'snapshots')},
@@ -1028,6 +959,7 @@ def post(self, uid):
             'snapshot': os.path.join(self.snapshots_url, ret),
         })
 
+
 class SingleSnapshotHandler(CollectionBaseHandler):
     @validate_auth
     @catch_error
@@ -1045,6 +977,7 @@ def get(self, uid):
         else:
             self.send_error(404, message='snapshot not found')
 
+
 class SingleSnapshotFilesHandler(CollectionBaseHandler):
     @validate_auth
     @catch_error
@@ -1055,33 +988,18 @@ def get(self, uid):
         if ret:
             try:
                 kwargs = urlargparse.parse(self.request.query)
-                if 'limit' in kwargs:
-                    kwargs['limit'] = int(kwargs['limit'])
-                    if kwargs['limit'] < 1:
-                        raise Exception('limit is not positive')
-
-                    # check with config
-                    if kwargs['limit'] > self.config['FC_QUERY_FILE_LIST_LIMIT']:
-                        kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-                else:
-                    # if no limit has been defined, set max limit
-                    kwargs['limit'] = self.config['FC_QUERY_FILE_LIST_LIMIT']
-
-                if 'start' in kwargs:
-                    kwargs['start'] = int(kwargs['start'])
-                    if kwargs['start'] < 0:
-                        raise Exception('start is negative')
-
+                argbuilder.build_limit(kwargs, self.config)
+                argbuilder.build_start(kwargs)
                 kwargs['query'] = {'uuid':{'$in':ret['files']}}
                 logger.warning('getting files: %r', kwargs['query'])
-
-                if 'keys' in kwargs:
-                    kwargs['keys'] = kwargs['keys'].split('|')
+                argbuilder.build_keys(kwargs)
             except:
                 logging.warn('query parameter error', exc_info=True)
                 self.send_error(400, message='invalid query parameters')
                 return
+
             files = yield self.db.find_files(**kwargs)
+
             self.write({
                 '_links':{
                     'self': {'href': os.path.join(self.snapshots_url,uid,'files')},
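Two things are worth noting in the server changes. First, every list endpoint now funnels its query string through the same `argbuilder` calls before delegating to the `Mongo` layer. Second, `build_files_query` mutates its dict in place rather than returning the query, which is why `CollectionsHandler.post()` re-encodes `metadata['query']` afterwards. A sketch of that POST-side transformation (field values are illustrative):

```python
from tornado.escape import json_encode
from file_catalog import argbuilder

# illustrative POST body for creating a collection
metadata = {"collection_name": "my_collection", "run_number": 123456}

argbuilder.build_files_query(metadata)
# metadata["query"] == {"locations.archive": None, "run.run_number": 123456}

metadata["query"] = json_encode(metadata["query"])
# the query is stored with the collection as a JSON string, ready to be
# json_decode()'d again in SingleCollectionFilesHandler.get()
```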
diff --git a/tests/test_files.py b/tests/test_files.py
index 0f631d5..3fd4c1a 100644
--- a/tests/test_files.py
+++ b/tests/test_files.py
@@ -1,14 +1,19 @@
+# fmt:off
+
 from __future__ import absolute_import, division, print_function
 
+import hashlib
 import os
 import unittest
-import hashlib
 
-from tornado.escape import json_encode,json_decode
+from tornado.escape import json_decode, json_encode
+
+# local imports
 from rest_tools.client import RestClient
 
 from .test_server import TestServerAPI
 
+
 def hex(data):
     if isinstance(data, str):
         data = data.encode('utf-8')
@@ -68,6 +73,76 @@ def test_11_files_count(self):
         self.assertIn('files', data)
         self.assertEqual(data['files'], 1)
 
+    def test_12_files_keys(self):
+        """Test the 'keys' and 'all-keys' arguments."""
+        self.start_server()
+        token = self.get_token()
+        r = RestClient(self.address, token, timeout=1, retries=1)
+
+        metadata = {
+            "logical_name": "blah",
+            "checksum": {"sha512": hex("foo bar")},
+            "file_size": 1,
+            "locations": [{"site": "test", "path": "blah.dat"}],
+            "extra": "foo",
+            "supplemental": ["green", "eggs", "ham"],
+        }
+        data = r.request_seq("POST", "/api/files", metadata)
+        self.assertIn("_links", data)
+        self.assertIn("self", data["_links"])
+        self.assertIn("file", data)
+        assert "extra" not in data
+        assert "supplemental" not in data
+        url = data["file"]
+        uid = url.split("/")[-1]
+
+        # w/o all-keys
+        data = r.request_seq("GET", "/api/files")
+        assert set(data["files"][0].keys()) == {"logical_name", "uuid"}
+
+        # w/ all-keys
+        args = {"all-keys": True}
+        data = r.request_seq("GET", "/api/files", args)
+        assert set(data["files"][0].keys()) == {
+            "logical_name",
+            "uuid",
+            "checksum",
+            "file_size",
+            "locations",
+            "extra",
+            "supplemental",
+            "meta_modify_date"
+        }
+
+        # w/ all-keys = False
+        args = {"all-keys": False}
+        data = r.request_seq("GET", "/api/files", args)
+        assert set(data["files"][0].keys()) == {"logical_name", "uuid"}
+
+        # w/ all-keys & keys
+        args = {"all-keys": True, "keys": "checksum|file_size"}
+        data = r.request_seq("GET", "/api/files", args)
+        assert set(data["files"][0].keys()) == {
+            "logical_name",
+            "uuid",
+            "checksum",
+            "file_size",
+            "locations",
+            "extra",
+            "supplemental",
+            "meta_modify_date"
+        }
+
+        # w/ all-keys = False & keys
+        args = {"all-keys": False, "keys": "checksum|file_size"}
+        data = r.request_seq("GET", "/api/files", args)
+        assert set(data["files"][0].keys()) == {"checksum", "file_size"}
+
+        # w/ just keys
+        args = {"keys": "checksum|file_size"}
+        data = r.request_seq("GET", "/api/files", args)
+        assert set(data["files"][0].keys()) == {"checksum", "file_size"}
+
     def test_15_files_auth(self):
         self.start_server(config_override={'SECRET':'secret'})
         token = self.get_token()
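The precedence these test cases exercise is decided entirely in `argbuilder.build_keys`: a truthy `all-keys` wins over any `keys` value. A minimal sketch at that layer (string values, as they would arrive after URL parsing):

```python
from file_catalog import argbuilder
from file_catalog.mongo import AllKeys

# a truthy "all-keys" wins, regardless of "keys"
kwargs = {"all-keys": "True", "keys": "checksum|file_size"}
argbuilder.build_keys(kwargs)
assert isinstance(kwargs["keys"], AllKeys)

# otherwise "keys" is split on "|" into a projection list
kwargs = {"all-keys": "False", "keys": "checksum|file_size"}
argbuilder.build_keys(kwargs)
assert kwargs["keys"] == ["checksum", "file_size"]
```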