From d5011a64f9f98eeb8972e63054e88c7222b2697f Mon Sep 17 00:00:00 2001 From: Christoph Ladurner Date: Thu, 3 Oct 2024 11:58:55 +0200 Subject: [PATCH] global: move to db.session.query syntax * this change is a working solution for sqlalchemy ~= 1.4 but a necessity for >= 2.0 --- invenio_files_rest/models.py | 101 ++++++++++++++++++++++------------- 1 file changed, 63 insertions(+), 38 deletions(-) diff --git a/invenio_files_rest/models.py b/invenio_files_rest/models.py index deca0ac5..1c3dac04 100644 --- a/invenio_files_rest/models.py +++ b/invenio_files_rest/models.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2019 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -116,7 +117,7 @@ def as_object_version(value): return ( value if isinstance(value, ObjectVersion) - else ObjectVersion.query.filter_by(version_id=value).one_or_none() + else db.session.query(ObjectVersion).filter_by(version_id=value).one_or_none() ) @@ -291,22 +292,26 @@ def validate_name(self, key, name): @classmethod def get_by_name(cls, name): """Fetch a specific location object.""" - return cls.query.filter_by( - name=name, - ).one_or_none() + return ( + db.session.query(cls) + .filter_by( + name=name, + ) + .one_or_none() + ) @classmethod def get_default(cls): """Fetch the default location object.""" try: - return cls.query.filter_by(default=True).one_or_none() + return db.session.query(cls).filter_by(default=True).one_or_none() except MultipleResultsFound: return None @classmethod def all(cls): """Return query that fetches all locations.""" - return Location.query.all() + return db.session.query(Location).all() def __repr__(self): """Return representation of location.""" @@ -550,12 +555,14 @@ def get(cls, bucket_id): :param bucket_id: Bucket identifier. :returns: Bucket instance. """ - return cls.query.filter_by(id=bucket_id, deleted=False).one_or_none() + return ( + db.session.query(cls).filter_by(id=bucket_id, deleted=False).one_or_none() + ) @classmethod def all(cls): """Return query of all buckets (excluding deleted).""" - return cls.query.filter_by(deleted=False) + return db.session.query(cls).filter_by(deleted=False) @classmethod def delete(cls, bucket_id): @@ -586,7 +593,7 @@ def remove(self): :returns: ``self``. """ with db.session.begin_nested(): - ObjectVersion.query.filter_by(bucket_id=self.id).delete( + db.session.query(ObjectVersion).filter_by(bucket_id=self.id).delete( synchronize_session=False ) self.query.filter_by(id=self.id).delete(synchronize_session=False) @@ -620,10 +627,14 @@ class BucketTag(db.Model): @classmethod def get(cls, bucket, key): """Get tag object.""" - return cls.query.filter_by( - bucket_id=as_bucket_id(bucket), - key=key, - ).one_or_none() + return ( + db.session.query(cls) + .filter_by( + bucket_id=as_bucket_id(bucket), + key=key, + ) + .one_or_none() + ) @classmethod def create(cls, bucket, key, value): @@ -654,7 +665,7 @@ def get_value(cls, bucket, key): def delete(cls, bucket, key): """Delete a tag.""" with db.session.begin_nested(): - cls.query.filter_by( + db.session.query(cls).filter_by( bucket_id=as_bucket_id(bucket), key=key, ).delete() @@ -723,13 +734,13 @@ def validate_uri(self, key, uri): @classmethod def get(cls, file_id): """Get a file instance.""" - return cls.query.filter_by(id=file_id).one_or_none() + return db.session.query(cls).filter_by(id=file_id).one_or_none() @classmethod def get_by_uri(cls, uri): """Get a file instance by URI.""" assert uri is not None - return cls.query.filter_by(uri=uri).one_or_none() + return db.session.query(cls).filter_by(uri=uri).one_or_none() @classmethod def create(cls): @@ -1260,9 +1271,11 @@ def create( raise BucketLockedError() with db.session.begin_nested(): - latest_obj = cls.query.filter( - cls.bucket == bucket, cls.key == key, cls.is_head.is_(True) - ).one_or_none() + latest_obj = ( + db.session.query(cls) + .filter(cls.bucket == bucket, cls.key == key, cls.is_head.is_(True)) + .one_or_none() + ) if latest_obj is not None: latest_obj.is_head = False db.session.add(latest_obj) @@ -1310,7 +1323,7 @@ def get(cls, bucket, key, version_id=None): filters.append(cls.is_head.is_(True)) filters.append(cls.file_id.isnot(None)) - return cls.query.filter(*filters).one_or_none() + return db.session.query(cls).filter(*filters).one_or_none() @classmethod def get_versions(cls, bucket, key, desc=True): @@ -1328,7 +1341,7 @@ def get_versions(cls, bucket, key, desc=True): order = cls.created.desc() if desc else cls.created.asc() - return cls.query.filter(*filters).order_by(cls.key, order) + return db.session.query(cls).filter(*filters).order_by(cls.key, order) @classmethod def delete(cls, bucket, key): @@ -1370,7 +1383,9 @@ def get_by_bucket(cls, bucket, versions=False, with_deleted=False): if not with_deleted: filters.append(cls.file_id.isnot(None)) - return cls.query.filter(*filters).order_by(cls.key, cls.created.desc()) + return ( + db.session.query(cls).filter(*filters).order_by(cls.key, cls.created.desc()) + ) @classmethod def relink_all(cls, old_file, new_file): @@ -1385,7 +1400,7 @@ def relink_all(cls, old_file, new_file): assert new_file.id with db.session.begin_nested(): - ObjectVersion.query.filter_by(file_id=str(old_file.id)).update( + db.session.query(ObjectVersion).filter_by(file_id=str(old_file.id)).update( {ObjectVersion.file_id: str(new_file.id)} ) @@ -1470,10 +1485,14 @@ def copy(self, object_version=None, key=None): @classmethod def get(cls, object_version, key): """Get the tag object.""" - return cls.query.filter_by( - version_id=as_object_version_id(object_version), - key=key, - ).one_or_none() + return ( + db.session.query(cls) + .filter_by( + version_id=as_object_version_id(object_version), + key=key, + ) + .one_or_none() + ) @classmethod def create(cls, object_version, key, value): @@ -1515,7 +1534,9 @@ def delete(cls, object_version, key=None): Default: delete all tags. """ with db.session.begin_nested(): - q = cls.query.filter_by(version_id=as_object_version_id(object_version)) + q = db.session.query(cls).filter_by( + version_id=as_object_version_id(object_version) + ) if key: q = q.filter_by(key=key) q.delete() @@ -1710,7 +1731,7 @@ def create(cls, bucket, key, size, chunk_size): @classmethod def get(cls, bucket, key, upload_id, with_completed=False): """Fetch a specific multipart object.""" - q = cls.query.filter_by( + q = db.session.query(cls).filter_by( upload_id=upload_id, bucket_id=as_bucket_id(bucket), key=key, @@ -1723,7 +1744,7 @@ def get(cls, bucket, key, upload_id, with_completed=False): @classmethod def query_expired(cls, dt, bucket=None): """Query all uncompleted multipart uploads.""" - q = cls.query.filter(cls.created < dt).filter_by(completed=True) + q = db.session.query(cls).filter(cls.created < dt).filter_by(completed=True) if bucket: q = q.filter(cls.bucket_id == as_bucket_id(bucket)) return q @@ -1731,7 +1752,7 @@ def query_expired(cls, dt, bucket=None): @classmethod def query_by_bucket(cls, bucket): """Query all uncompleted multipart uploads.""" - return cls.query.filter(cls.bucket_id == as_bucket_id(bucket)) + return db.session.query(cls).filter(cls.bucket_id == as_bucket_id(bucket)) class Part(db.Model, Timestamp): @@ -1792,9 +1813,11 @@ def create(cls, mp, part_number, stream=None, **kwargs): @classmethod def get_or_none(cls, mp, part_number): """Get part number.""" - return cls.query.filter_by( - upload_id=mp.upload_id, part_number=part_number - ).one_or_none() + return ( + db.session.query(cls) + .filter_by(upload_id=mp.upload_id, part_number=part_number) + .one_or_none() + ) @classmethod def get_or_create(cls, mp, part_number): @@ -1807,9 +1830,11 @@ def get_or_create(cls, mp, part_number): @classmethod def delete(cls, mp, part_number): """Get part number.""" - return cls.query.filter_by( - upload_id=mp.upload_id, part_number=part_number - ).delete() + return ( + db.session.query(cls) + .filter_by(upload_id=mp.upload_id, part_number=part_number) + .delete() + ) @classmethod def query_by_multipart(cls, multipart): @@ -1822,7 +1847,7 @@ def query_by_multipart(cls, multipart): upload_id = ( multipart.upload_id if isinstance(multipart, MultipartObject) else multipart ) - return cls.query.filter_by(upload_id=upload_id) + return db.session.query(cls).filter_by(upload_id=upload_id) @classmethod def count(cls, mp):