diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index 37b58d050..ea69e78e1 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -592,8 +592,11 @@ def search_library( statement = statement.where(Entry.suffix.in_(extensions)) statement = statement.distinct(Entry.id) + start_time = time.time() query_count = select(func.count()).select_from(statement.alias("entries")) count_all: int = session.execute(query_count).scalar() + end_time = time.time() + logger.info(f"finished counting ({format_timespan(end_time-start_time)})") sort_on: ColumnExpressionArgument = Entry.id match search.sorting_mode: @@ -609,9 +612,14 @@ def search_library( query_full=str(statement.compile(compile_kwargs={"literal_binds": True})), ) + start_time = time.time() + items = session.scalars(statement).fetchall() + end_time = time.time() + logger.info(f"SQL Execution finished ({format_timespan(end_time - start_time)})") + res = SearchResult( total_count=count_all, - items=list(session.scalars(statement)), + items=list(items), ) session.expunge_all() diff --git a/tagstudio/src/core/library/alchemy/visitors.py b/tagstudio/src/core/library/alchemy/visitors.py index f88bc2a77..97d95572e 100644 --- a/tagstudio/src/core/library/alchemy/visitors.py +++ b/tagstudio/src/core/library/alchemy/visitors.py @@ -5,12 +5,11 @@ from typing import TYPE_CHECKING import structlog -from sqlalchemy import and_, distinct, func, or_, select, text +from sqlalchemy import ColumnElement, and_, distinct, func, or_, select, text from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories from src.core.query_lang import BaseVisitor -from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property +from src.core.query_lang.ast import ANDList, Constraint, ConstraintType, Not, ORList, Property from .joins import TagEntry from .models import Entry, Tag, TagAlias @@ -33,7 +32,7 @@ FROM tag_parents tp INNER JOIN ChildTags c ON tp.child_id = c.child_id ) -SELECT * FROM ChildTags; +SELECT child_id FROM ChildTags; """) # noqa: E501 @@ -44,17 +43,17 @@ def get_filetype_equivalency_list(item: str) -> list[str] | set[str]: return [item] -class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]): +class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]): def __init__(self, lib: Library) -> None: super().__init__() self.lib = lib - def visit_or_list(self, node: ORList) -> ColumnExpressionArgument: + def visit_or_list(self, node: ORList) -> ColumnElement[bool]: return or_(*[self.visit(element) for element in node.elements]) - def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument: + def visit_and_list(self, node: ANDList) -> ColumnElement[bool]: tag_ids: list[int] = [] - bool_expressions: list[ColumnExpressionArgument] = [] + bool_expressions: list[ColumnElement[bool]] = [] # Search for TagID / unambiguous Tag Constraints and store the respective tag ids separately for term in node.terms: @@ -74,7 +73,7 @@ def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument: tag_ids.append(ids[0]) continue - bool_expressions.append(self.__entry_satisfies_ast(term)) + bool_expressions.append(self.visit(term)) # If there are at least two tag ids use a relational division query # to efficiently check all of them @@ -88,15 +87,15 @@ def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument: return and_(*bool_expressions) - def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument: + def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: """Returns a Boolean Expression that is true, if the Entry satisfies the constraint.""" if len(node.properties) != 0: raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG if node.type == ConstraintType.Tag: - return Entry.tags.any(Tag.id.in_(self.__get_tag_ids(node.value))) + return self.__entry_matches_tag_ids(self.__get_tag_ids(node.value)) elif node.type == ConstraintType.TagID: - return Entry.tags.any(Tag.id == int(node.value)) + return self.__entry_matches_tag_ids([int(node.value)]) elif node.type == ConstraintType.Path: return Entry.path.op("GLOB")(node.value) elif node.type == ConstraintType.MediaType: @@ -120,8 +119,17 @@ def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument: def visit_property(self, node: Property) -> None: raise NotImplementedError("This should never be reached!") - def visit_not(self, node: Not) -> ColumnExpressionArgument: - return ~self.__entry_satisfies_ast(node.child) + def visit_not(self, node: Not) -> ColumnElement[bool]: + return ~self.visit(node.child) + + def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnElement[bool]: + """Returns a boolean expression that is true if the entry has at least one of the supplied tags.""" # noqa: E501 + return ( + select(1) + .correlate(Entry) + .where(and_(TagEntry.entry_id == Entry.id, TagEntry.tag_id.in_(tag_ids))) + .exists() + ) def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]: """Given a tag name find the ids of all tags that this name could refer to.""" @@ -146,24 +154,17 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id}))) return outp - def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]: + def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]: """Returns Binary Expression that is true if the Entry has all provided tag ids.""" # Relational Division Query return Entry.id.in_( - select(Entry.id) - .outerjoin(TagEntry) + select(TagEntry.entry_id) .where(TagEntry.tag_id.in_(tag_ids)) - .group_by(Entry.id) + .group_by(TagEntry.entry_id) .having(func.count(distinct(TagEntry.tag_id)) == len(tag_ids)) ) - def __entry_satisfies_ast(self, partial_query: AST) -> BinaryExpression[bool]: - """Returns Binary Expression that is true if the Entry satisfies the partial query.""" - return self.__entry_satisfies_expression(self.visit(partial_query)) - - def __entry_satisfies_expression( - self, expr: ColumnExpressionArgument - ) -> BinaryExpression[bool]: + def __entry_satisfies_expression(self, expr: ColumnElement[bool]) -> ColumnElement[bool]: """Returns Binary Expression that is true if the Entry satisfies the column expression. Executed on: Entry ⟕ TagEntry (Entry LEFT OUTER JOIN TagEntry).