From 2f459c7f55b7c2e943a5fc05b11a7bdb6fac5844 Mon Sep 17 00:00:00 2001 From: Rudolf Cardinal Date: Thu, 9 Jan 2025 11:08:13 +0000 Subject: [PATCH] Try to make research_database_info less singleton-like; fix test suite for RNC environment --- .../tests/researcher_report_tests.py | 18 +- crate_anon/crateweb/anonymise_api/tests.py | 2 + crate_anon/crateweb/config/urls.py | 3 +- crate_anon/crateweb/consent/models.py | 3 +- crate_anon/crateweb/consent/views.py | 4 +- crate_anon/crateweb/research/archive_func.py | 3 +- crate_anon/crateweb/research/models.py | 37 +-- .../crateweb/research/research_db_info.py | 234 ++++++++++++------ crate_anon/crateweb/research/sql_writer.py | 3 +- .../research/tests/research_db_info_tests.py | 167 +++++++++++++ crate_anon/crateweb/research/views.py | 28 ++- .../pipeline_test.sh | 6 + 12 files changed, 405 insertions(+), 103 deletions(-) create mode 100644 crate_anon/crateweb/research/tests/research_db_info_tests.py diff --git a/crate_anon/anonymise/tests/researcher_report_tests.py b/crate_anon/anonymise/tests/researcher_report_tests.py index 05f1c484..3aa9726b 100644 --- a/crate_anon/anonymise/tests/researcher_report_tests.py +++ b/crate_anon/anonymise/tests/researcher_report_tests.py @@ -27,15 +27,15 @@ """ +import os.path import random -from tempfile import NamedTemporaryFile +from tempfile import TemporaryDirectory from typing import List, TYPE_CHECKING from unittest import mock import factory from pypdf import PdfReader import pytest - from sqlalchemy import ( Column, DateTime, @@ -43,10 +43,8 @@ Integer, Text, ) - from sqlalchemy.orm import relationship - from crate_anon.anonymise.researcher_report import ( mk_researcher_report_pdf, ResearcherReportConfig, @@ -160,6 +158,8 @@ def setUp(self) -> None: ) self.anon_dbsession.commit() + self.tempdir = TemporaryDirectory() + @pytest.mark.usefixtures("django_test_settings") def test_report_has_pages_for_each_table(self) -> None: def index_of_list_substring(items: List[str], substr: str) -> int: @@ -171,7 +171,9 @@ def index_of_list_substring(items: List[str], substr: str) -> int: anon_config = mock.Mock() - with NamedTemporaryFile(delete=False, mode="w") as f: + reportfilename = os.path.join(self.tempdir.name, "tmpreport.pdf") + + with open(reportfilename, mode="w") as f: mock_db = mock.Mock( table_names=["anon_patient", "anon_note"], metadata=AnonTestBase.metadata, @@ -182,7 +184,7 @@ def index_of_list_substring(items: List[str], substr: str) -> int: __post_init__=mock.Mock(), ): report_config = ResearcherReportConfig( - output_filename=f.name, + output_filename=reportfilename, anonconfig=anon_config, use_dd=False, ) @@ -190,7 +192,7 @@ def index_of_list_substring(items: List[str], substr: str) -> int: report_config.db = mock_db mk_researcher_report_pdf(report_config) - with open(f.name, "rb") as f: + with open(reportfilename, "rb") as f: reader = PdfReader(f) patient_found = False @@ -212,7 +214,7 @@ def index_of_list_substring(items: List[str], substr: str) -> int: num_rows = int(lines[rows_index + 1]) table_name = lines[0] - if table_name == "patient": + if table_name == "anon_patient": patient_found = True self.assertEqual(num_rows, self.num_patients) diff --git a/crate_anon/crateweb/anonymise_api/tests.py b/crate_anon/crateweb/anonymise_api/tests.py index 338e2001..4b3a32b9 100644 --- a/crate_anon/crateweb/anonymise_api/tests.py +++ b/crate_anon/crateweb/anonymise_api/tests.py @@ -49,6 +49,8 @@ @override_settings(ANONYMISE_API=DEFAULT_SETTINGS) class AnonymisationTests(TestCase): + databases = {"default", "research"} + def setUp(self) -> None: super().setUp() diff --git a/crate_anon/crateweb/config/urls.py b/crate_anon/crateweb/config/urls.py index cbec02e2..b72b9667 100644 --- a/crate_anon/crateweb/config/urls.py +++ b/crate_anon/crateweb/config/urls.py @@ -79,9 +79,10 @@ and EnvVar.RUNNING_TESTS not in os.environ ): from crate_anon.crateweb.research.research_db_info import ( - research_database_info, + get_research_db_info, ) + research_database_info = get_research_db_info() research_database_info.get_colinfolist() log = logging.getLogger(__name__) diff --git a/crate_anon/crateweb/consent/models.py b/crate_anon/crateweb/consent/models.py index fdf51b9a..9d121caf 100644 --- a/crate_anon/crateweb/consent/models.py +++ b/crate_anon/crateweb/consent/models.py @@ -111,7 +111,7 @@ ) from crate_anon.crateweb.research.models import get_mpid from crate_anon.crateweb.research.research_db_info import ( - research_database_info, + get_research_db_info, ) from crate_anon.crateweb.userprofile.models import UserProfile @@ -2038,6 +2038,7 @@ def process_request_main(self) -> None: ) # delayed import # Translate to an NHS number + research_database_info = get_research_db_info() dbinfo = research_database_info.dbinfo_for_contact_lookup if self.lookup_nhs_number is not None: self.nhs_number = self.lookup_nhs_number diff --git a/crate_anon/crateweb/consent/views.py b/crate_anon/crateweb/consent/views.py index cc283bc6..62862fde 100644 --- a/crate_anon/crateweb/consent/views.py +++ b/crate_anon/crateweb/consent/views.py @@ -85,7 +85,7 @@ ) from crate_anon.crateweb.extra.pdf import serve_html_or_pdf from crate_anon.crateweb.research.research_db_info import ( - research_database_info, + get_research_db_info, ) from crate_anon.crateweb.research.models import PidLookup, get_mpid from crate_anon.crateweb.userprofile.models import UserProfile @@ -504,6 +504,7 @@ def submit_contact_request(request: HttpRequest) -> HttpResponse: Args: request: the :class:`django.http.request.HttpRequest` """ + research_database_info = get_research_db_info() dbinfo = research_database_info.dbinfo_for_contact_lookup if request.user.is_superuser: form = SuperuserSubmitContactRequestForm( @@ -579,6 +580,7 @@ def clinician_initiated_contact_request(request: HttpRequest) -> HttpResponse: Args: request: the :class:`django.http.request.HttpRequest` """ + research_database_info = get_research_db_info() dbinfo = research_database_info.dbinfo_for_contact_lookup email = request.user.email userprofile = UserProfile.objects.get(user=request.user) diff --git a/crate_anon/crateweb/research/archive_func.py b/crate_anon/crateweb/research/archive_func.py index 42ce60ea..f74ae2b0 100644 --- a/crate_anon/crateweb/research/archive_func.py +++ b/crate_anon/crateweb/research/archive_func.py @@ -46,7 +46,7 @@ archive_template_url, ) from crate_anon.crateweb.research.research_db_info import ( - research_database_info, + get_research_db_info, ) from crate_anon.crateweb.research.views import ( FN_SRCDB, @@ -151,6 +151,7 @@ def delimit_sql_identifier(identifer: str) -> str: """ Delimits (quotes) an SQL identifier, if required. """ + research_database_info = get_research_db_info() return research_database_info.grammar.quote_identifier_if_required( identifer ) diff --git a/crate_anon/crateweb/research/models.py b/crate_anon/crateweb/research/models.py index 6fe6cfd2..f2a692fe 100644 --- a/crate_anon/crateweb/research/models.py +++ b/crate_anon/crateweb/research/models.py @@ -48,7 +48,7 @@ register_for_json, ) from cardinal_pythonlib.reprfunc import simple_repr -from cardinal_pythonlib.sql.sql_grammar import format_sql, SqlGrammar +from cardinal_pythonlib.sql.sql_grammar import format_sql from cardinal_pythonlib.tsv import make_tsv_row from cardinal_pythonlib.django.function_cache import django_cache_function from django.db import connections, DatabaseError, models @@ -82,7 +82,7 @@ ) from crate_anon.crateweb.research.research_db_info import ( RESEARCH_DB_CONNECTION_NAME, - research_database_info, + get_research_db_info, SingleResearchDatabase, ) from crate_anon.crateweb.research.sql_writer import ( @@ -185,6 +185,7 @@ def database_last_updated(dbname: str) -> Optional[datetime.datetime]: are null, the function will return the minimum date possible. If there are no such tables, the function will return None. """ + research_database_info = get_research_db_info() try: dbinfo = research_database_info.get_dbinfo_by_name(dbname) except ValueError: @@ -1704,7 +1705,7 @@ def _get_select_mrid_column(self) -> Optional[ColumnId]: """ if not self._patient_conditions: return None - return research_database_info.get_linked_mrid_column( + return self._research_database_info.get_linked_mrid_column( self._patient_conditions[0].table_id ) @@ -1748,7 +1749,6 @@ def patient_id_query(self, with_order_by: bool = True) -> str: if not self._patient_conditions: return "" - grammar = research_database_info.grammar select_mrid_column = self._get_select_mrid_column() if not select_mrid_column.is_valid: log.warning( @@ -1758,6 +1758,10 @@ def patient_id_query(self, with_order_by: bool = True) -> str: # One way this can happen: (1) a user saves a PMQ; (2) the # administrator removes one of the databases! return "" + + research_database_info = get_research_db_info() + grammar = research_database_info.grammar + mrid_alias = "_mrid" sql = add_to_select( "", @@ -1847,7 +1851,7 @@ def all_queries(self, mrids: List[Any] = None) -> List[TableQueryArgs]: return queries def where_patient_clause( - self, table_id: TableId, grammar: SqlGrammar, mrids: List[Any] = None + self, table_id: TableId, mrids: List[Any] = None ) -> SqlArgsTupleType: """ Returns an SQL WHERE clauses similar to ``sometable.mrid IN (1, 2, 3)`` @@ -1857,15 +1861,13 @@ def where_patient_clause( Args: table_id: :class:`crate_anon.common.sql.TableId` for the table whose MRID column we will apply the ``WHERE`` clause to - grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` - to use mrids: list of MRIDs; if this is ``None`` or empty, use the patients fetched (live) by our :meth:`patient_id_query`. Returns: tuple: ``sql, args`` """ - mrid_column = research_database_info.get_mrid_column_from_table( + mrid_column = self._research_database_info.get_mrid_column_from_table( table_id ) if mrids: @@ -1881,6 +1883,8 @@ def where_patient_clause( # derived tables, subqueries, ... unless TOP, OFFSET or FOR XML # is specified." args = [] # type: List[Any] + research_database_info = get_research_db_info() + grammar = research_database_info.grammar sql = f"{mrid_column.identifier(grammar)} IN ({in_clause})" return sql, args @@ -1914,6 +1918,7 @@ def make_query( """ if not columns: raise ValueError("No columns specified") + research_database_info = get_research_db_info() grammar = research_database_info.grammar mrid_column = research_database_info.get_mrid_column_from_table( table_id @@ -1922,9 +1927,7 @@ def make_query( for c in columns: if c not in all_columns: all_columns.append(c) - where_clause, args = self.where_patient_clause( - table_id, grammar, mrids - ) + where_clause, args = self.where_patient_clause(table_id, mrids) select_elements = [SelectElement(column_id=col) for col in all_columns] where_conditions = [WhereCondition(raw_sql=where_clause)] sql = add_to_select( @@ -1946,6 +1949,7 @@ def output_cols_html(self) -> str: """ Returns all our output columns in HTML format. """ + research_database_info = get_research_db_info() grammar = research_database_info.grammar return prettify_sql_html( "\n".join( @@ -1961,6 +1965,7 @@ def pt_conditions_html(self) -> str: """ Returns all our patient WHERE conditions in HTML format. """ + research_database_info = get_research_db_info() grammar = research_database_info.grammar return prettify_sql_html( "\nAND ".join([wc.sql(grammar) for wc in self.patient_conditions]) @@ -2030,6 +2035,7 @@ def gen_data_finder_queries( :class:`TableQueryArgs` objects (q.v.) """ + research_database_info = get_research_db_info() grammar = research_database_info.grammar mrid_alias = "master_research_id" table_name_alias = "table_name" @@ -2053,9 +2059,7 @@ def gen_data_finder_queries( max_date = "NULL" # ... OK (at least in MySQL) to do: # SELECT col1, COUNT(*), NULL FROM table GROUP BY col1; - where_clause, args = self.where_patient_clause( - table_id, grammar, mrids - ) + where_clause, args = self.where_patient_clause(table_id, mrids) table_identifier = table_id.identifier(grammar) select_elements = [ SelectElement(column_id=mrid_col, alias=mrid_alias), @@ -2116,6 +2120,7 @@ def gen_monster_queries( :class:`TableQueryArgs` objects (q.v.) """ + research_database_info = get_research_db_info() grammar = research_database_info.grammar for ( table_id @@ -2123,9 +2128,7 @@ def gen_monster_queries( mrid_col = research_database_info.get_mrid_column_from_table( table=table_id ) - where_clause, args = self.where_patient_clause( - table_id, grammar, mrids - ) + where_clause, args = self.where_patient_clause(table_id, mrids) # We add the WHERE using our magic query machine, to get the joins # right: select_elements = [ diff --git a/crate_anon/crateweb/research/research_db_info.py b/crate_anon/crateweb/research/research_db_info.py index 82e4ec00..e1e92551 100644 --- a/crate_anon/crateweb/research/research_db_info.py +++ b/crate_anon/crateweb/research/research_db_info.py @@ -211,7 +211,6 @@ def __init__( grammar: SqlGrammar, rdb_info: "ResearchDatabaseInfo", connection: BaseDatabaseWrapper, - vendor: str, ) -> None: """ Instantiates, reading database information as follows: @@ -231,16 +230,15 @@ def __init__( the research database connection: a :class:`django.db.backends.base.base.BaseDatabaseWrapper`, - i.e. a Django database connection - vendor: - the Django database vendor name; see e.g. + i.e. a Django database connection. This includes + connection.vendor, the Django database vendor name; see e.g. https://docs.djangoproject.com/en/2.1/ref/models/options/ """ assert 0 <= index <= len(settings.RESEARCH_DB_INFO) infodict = settings.RESEARCH_DB_INFO[index] - self.connection = connection - self.vendor = vendor + # Don't store self.connection; the Django cache will pickle and it is + # not pickleable. self.index = index self.is_first_db = index == 0 @@ -389,9 +387,7 @@ def __init__( assert self.schema_id # Now discover the schema - self._schema_infodictlist = ( - None - ) # type: Optional[List[Dict[str, Any]]] + self._schema_infodictlist = self.get_schema_infodictlist(connection) self._colinfolist = None # type: Optional[List[ColumnInfo]] @property @@ -401,10 +397,6 @@ def schema_infodictlist(self) -> List[Dict[str, Any]]: :meth:`get_schema_infodictlist` for our connection and vendor. Implements caching. """ - if self._schema_infodictlist is None: - self._schema_infodictlist = self.get_schema_infodictlist( - self.connection, self.vendor - ) return self._schema_infodictlist @property @@ -524,7 +516,7 @@ def column_present(self, column_id: ColumnId) -> bool: return False # ------------------------------------------------------------------------- - # Fetching schema info from the database + # Fetching schema info from the database: internals # ------------------------------------------------------------------------- @classmethod @@ -532,7 +524,7 @@ def _schema_query_microsoft( cls, db_name: str, schema_names: List[str] ) -> SqlArgsTupleType: """ - Returns a query to fetche the database structure from an SQL Server + Returns a query to fetch the database structure from an SQL Server database. The columns returned are as expected by @@ -573,15 +565,20 @@ def _schema_query_microsoft( d.column_type, d.column_comment, CASE WHEN COUNT(d.index_id) > 0 THEN 1 ELSE 0 END AS indexed, - CASE WHEN COUNT(d.fulltext_index_object_id) > 0 THEN 1 ELSE 0 END AS indexed_fulltext + CASE + WHEN COUNT(d.fulltext_index_object_id) > 0 THEN 1 + ELSE 0 + END AS indexed_fulltext FROM ( SELECT s.name AS table_schema, ta.name AS table_name, c.name AS column_name, c.is_nullable, - UPPER(ty.name) + '(' + CONVERT(VARCHAR(100), c.max_length) + ')' AS column_type, - CONVERT(VARCHAR(1000), x.value) AS column_comment, -- x.value is of type SQL_VARIANT + UPPER(ty.name) + '(' + CONVERT(VARCHAR(100), c.max_length) + ')' + AS column_type, + CONVERT(VARCHAR(1000), x.value) AS column_comment, + -- x.value is of type SQL_VARIANT i.index_id, fi.object_id AS fulltext_index_object_id FROM [{db_name}].sys.tables ta @@ -601,7 +598,8 @@ def _schema_query_microsoft( AND fi.column_id = c.column_id ) WHERE s.name IN ({schema_placeholder}) - AND ty.user_type_id = ty.system_type_id -- restricts to system data types; eliminates 'sysname' type + AND ty.user_type_id = ty.system_type_id + -- restricts to system data types; eliminates 'sysname' type ) AS d GROUP BY table_schema, @@ -614,7 +612,7 @@ def _schema_query_microsoft( table_schema, table_name, column_name - """ # noqa: E501 + """ ) args = [db_name] + schema_names return sql, args @@ -622,7 +620,7 @@ def _schema_query_microsoft( @classmethod def _schema_query_mysql(cls, db_and_schema_name: str) -> SqlArgsTupleType: """ - Returns a query to fetche the database structure from a MySQL database. + Returns a query to fetch the database structure from a MySQL database. The columns returned are as expected by :func:`get_schema_infodictlist`. @@ -739,7 +737,7 @@ def _schema_query_mysql(cls, db_and_schema_name: str) -> SqlArgsTupleType: table_schema, table_name, column_name - """ + """ ) args = [db_and_schema_name] return sql, args @@ -749,7 +747,7 @@ def _schema_query_postgres( cls, schema_names: List[str] ) -> SqlArgsTupleType: """ - Returns a query to fetche the database structure from an SQL Server + Returns a query to fetch the database structure from an SQL Server database. The columns returned are as expected by @@ -844,13 +842,112 @@ def _schema_query_postgres( table_schema, table_name, column_name - """ + """ ) args = schema_names return sql, args + @classmethod + def _schema_query_sqlite_as_infodictlist( + cls, connection: BaseDatabaseWrapper, debug: bool = False + ) -> List[Dict[str, Any]]: + """ + Queries an SQLite databases and returns columns as expected by + :func:`get_schema_infodictlist`. + """ + # 1. Catalogue tables. + # pragma table_info(sqlite_master); + empty_args = [] + sql_get_tables = """ + SELECT tbl_name AS tablename + FROM sqlite_master + WHERE type='table' + """ + table_info_rows = cls._exec_sql_query( + connection, (sql_get_tables, empty_args), debug=debug + ) + table_names = [row["tablename"] for row in table_info_rows] + + # 2. Catalogue each tables + results = [] # type: List[Dict[str, Any]] + for table_name in table_names: + # A "PRAGMA table_info()" call doesn't work with arguments. + sql_inspect_table = f"PRAGMA table_info({table_name})" + column_info_rows = cls._exec_sql_query( + connection, (sql_inspect_table, empty_args), debug=debug + ) + for ci in column_info_rows: + results.append( + dict( + table_catalog="", + table_schema="", + table_name=table_name, + column_name=ci["name"], + is_nullable=1 - ci["notnull"], + column_type=ci["type"], + column_comment="", + indexed=0, + indexed_fulltext=0, + ) + ) + # Ignored: + # - "cid" (column ID) + # - "dflt_value" + # - "pk" + return results + + @classmethod + def _exec_sql_query( + cls, + connection: BaseDatabaseWrapper, + sql_args: SqlArgsTupleType, + debug: bool = False, + ) -> List[Dict[str, Any]]: + """ + Used by get_schema_infodictlist() as a common function to translate an + sql/args pair into the desired results. But it does that because the + incoming SQL has the right column names; the function is more generic + and just runs a query. + + Args: + connection: + a :class:`django.db.backends.base.base.BaseDatabaseWrapper`, + i.e. a Django database connection + sql_args: + tuple of SQL and arguments + debug: + be verbose to the log? + + Returns: + A list of dictionaries, each mapping column names to values. + The dictionaries are suitable for use as ``**kwargs`` to + :class:`ColumnInfo`. + """ + # We execute this one directly, rather than using the Query class, + # since this is a system rather than a per-user query. + sql, args = sql_args + cursor = connection.cursor() + if debug: + log.debug(f"- sql = {sql}\n- args = {args!r}") + cursor.execute(sql, args) + # Re passing multiple values to SQL via args: + # - Don't circumvent the parameter protection against SQL injection. + # - Too much hassle to use Django's ORM model here, though that would + # also be possible. + # - https://stackoverflow.com/questions/907806 + # - Similarly via SQLAlchemy reflection/inspection. + results = dictfetchall(cursor) # list of OrderedDicts + if debug: + log.debug(f"results = {results!r}") + log.debug("... done") + return results + + # ------------------------------------------------------------------------- + # Fetching schema info from the database: main (still internal) interface + # ------------------------------------------------------------------------- + def get_schema_infodictlist( - self, connection: BaseDatabaseWrapper, vendor: str, debug: bool = False + self, connection: BaseDatabaseWrapper, debug: bool = False ) -> List[Dict[str, Any]]: """ Fetch structure information for a specific database, by asking the @@ -860,9 +957,6 @@ def get_schema_infodictlist( connection: a :class:`django.db.backends.base.base.BaseDatabaseWrapper`, i.e. a Django database connection - vendor: - the Django database vendor name; see e.g. - https://docs.djangoproject.com/en/2.1/ref/models/options/ debug: be verbose to the log? @@ -879,49 +973,53 @@ def get_schema_infodictlist( f"{db_name!r}, schema {schema_name!r})..." ) # The db/schema names are guaranteed to be strings by __init__(). - if vendor == ConnectionVendors.MICROSOFT: + if connection.vendor == ConnectionVendors.MICROSOFT: if not db_name: - raise ValueError("No db_name specified; required for MSSQL") + raise ValueError(f"{db_name=!r}; required for MSSQL") if not schema_name: - raise ValueError( - "No schema_name specified; required for MSSQL" - ) - sql, args = self._schema_query_microsoft(db_name, [schema_name]) - elif vendor == ConnectionVendors.POSTGRESQL: + raise ValueError(f"{schema_name=!r}; required for MSSQL") + results = self._exec_sql_query( + connection, + sql_args=self._schema_query_microsoft(db_name, [schema_name]), + debug=debug, + ) + elif connection.vendor == ConnectionVendors.POSTGRESQL: if db_name: - raise ValueError( - "db_name specified; must be '' for PostgreSQL" - ) + raise ValueError(f"{db_name=!r}; must be '' for PostgreSQL") if not schema_name: - raise ValueError( - "No schema_name specified; required for PostgreSQL" - ) - sql, args = self._schema_query_postgres([schema_name]) - elif vendor == ConnectionVendors.MYSQL: + raise ValueError(f"{schema_name=!r}; required for PostgreSQL") + results = self._exec_sql_query( + connection, + sql_args=self._schema_query_postgres([schema_name]), + debug=debug, + ) + elif connection.vendor == ConnectionVendors.MYSQL: if db_name: - raise ValueError("db_name specified; must be '' for MySQL") + raise ValueError(f"{db_name=!r}; must be '' for MySQL") if not schema_name: - raise ValueError( - "No schema_name specified; required for MySQL" - ) - sql, args = self._schema_query_mysql( - db_and_schema_name=schema_name + raise ValueError(f"{schema_name=!r}; required for MySQL") + results = self._exec_sql_query( + connection, + sql_args=self._schema_query_mysql( + db_and_schema_name=schema_name + ), + debug=debug, + ) + elif connection.vendor == ConnectionVendors.SQLITE: + # db_name: don't care? + # schema_name: don't care? + # This one can't be done as a single query; the following function + # builds up the information by querying a list of tables, then each + # table. + results = self._schema_query_sqlite_as_infodictlist( + connection, debug=debug ) else: raise ValueError( f"Don't know how to get metadata for " - f"connection.vendor=='{vendor}'" + f"{connection.vendor=!r}" ) - # We execute this one directly, rather than using the Query class, - # since this is a system rather than a per-user query. - cursor = connection.cursor() - if debug: - log.debug(f"sql = {sql}, args = {args!r}") - cursor.execute(sql, args) - results = dictfetchall(cursor) # list of OrderedDicts - if debug: - log.debug(f"results = {results!r}") - log.debug("... done") + if not results: log.warning( f"SingleResearchDatabase.get_schema_infodictlist(): no " @@ -929,12 +1027,6 @@ def get_schema_infodictlist( f"database - misconfigured?" ) return results - # Re passing multiple values to SQL via args: - # - Don't circumvent the parameter protection against SQL injection. - # - Too much hassle to use Django's ORM model here, though that would - # also be possible. - # - https://stackoverflow.com/questions/907806 - # - Similarly via SQLAlchemy reflection/inspection. @register_for_json(method=METHOD_NO_ARGS) @@ -954,10 +1046,10 @@ class ResearchDatabaseInfo: # We fetch the dialect at first request; this enables us to import the # class without Django configured. - def __init__(self) -> None: + def __init__(self, running_without_config: bool = False) -> None: self.dbinfolist = [] # type: List[SingleResearchDatabase] - if RUNNING_WITHOUT_CONFIG: + if running_without_config: self.dialect = "" self.grammar = None # type: Optional[SqlGrammar] self.dbinfo_for_contact_lookup = ( @@ -973,7 +1065,6 @@ def __init__(self) -> None: self.grammar = make_grammar(self.dialect) # not expensive connection = self._connection() - vendor = connection.vendor for index in range(len(settings.RESEARCH_DB_INFO)): self.dbinfolist.append( @@ -982,7 +1073,6 @@ def __init__(self) -> None: grammar=self.grammar, rdb_info=self, connection=connection, - vendor=vendor, ) ) assert ( @@ -1597,4 +1687,6 @@ def get_mrid_linkable_patient_tables(self) -> List[TableId]: return sorted(list(eligible_tables)) -research_database_info = ResearchDatabaseInfo() +@django_cache_function(timeout=None) +def get_research_db_info() -> ResearchDatabaseInfo: + return ResearchDatabaseInfo(RUNNING_WITHOUT_CONFIG) diff --git a/crate_anon/crateweb/research/sql_writer.py b/crate_anon/crateweb/research/sql_writer.py index 269fb65d..c70744a3 100644 --- a/crate_anon/crateweb/research/sql_writer.py +++ b/crate_anon/crateweb/research/sql_writer.py @@ -50,7 +50,7 @@ ) from crate_anon.crateweb.research.errors import DatabaseStructureNotUnderstood from crate_anon.crateweb.research.research_db_info import ( - research_database_info, + get_research_db_info, ) log = logging.getLogger(__name__) @@ -105,6 +105,7 @@ def get_join_info( - ``INNER JOIN`` etc. is part of ANSI SQL """ + research_database_info = get_research_db_info() first_from_table = get_first_from_table(parsed) from_table_in_join_schema = get_first_from_table( parsed, match_db=jointable.db, match_schema=jointable.schema diff --git a/crate_anon/crateweb/research/tests/research_db_info_tests.py b/crate_anon/crateweb/research/tests/research_db_info_tests.py new file mode 100644 index 00000000..bbc3eee0 --- /dev/null +++ b/crate_anon/crateweb/research/tests/research_db_info_tests.py @@ -0,0 +1,167 @@ +""" +crate_anon/crateweb/research/tests/research_db_info_tests.py + +=============================================================================== + + Copyright (C) 2015, University of Cambridge, Department of Psychiatry. + Created by Rudolf Cardinal (rnc1001@cam.ac.uk). + + This file is part of CRATE. + + CRATE is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + CRATE is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with CRATE. If not, see . + +=============================================================================== + +Test research_db_info.py. + +""" + +# ============================================================================= +# Imports +# ============================================================================= + +import logging +import os.path +from tempfile import TemporaryDirectory + +from cardinal_pythonlib.dbfunc import dictfetchall +from cardinal_pythonlib.sql.sql_grammar import SqlGrammar +from django.db import connections, DEFAULT_DB_ALIAS +from django.test.testcases import TestCase # inherits from unittest.TestCase + +from crate_anon.crateweb.config.constants import ResearchDbInfoKeys as RDIKeys +from crate_anon.crateweb.research.research_db_info import ( + SingleResearchDatabase, + ResearchDatabaseInfo, +) + +log = logging.getLogger(__name__) + + +# ============================================================================= +# Constants +# ============================================================================= + + +# ============================================================================= +# Unit tests +# ============================================================================= +class ResearchDBInfoTests(TestCase): + _research_django_db_name = "research" + + databases = {DEFAULT_DB_ALIAS, _research_django_db_name} + # ... or the test framework will produce this: + # + # django.test.testcases.DatabaseOperationForbidden: Database queries to + # 'research' are not allowed in this test. Add 'research' to + # research_db_info_tests.ResearchDBInfoTests.databases to ensure proper + # test isolation and silence this failure. + # + # It is checked by a classmethod, not an instance. + + def setUp(self): + super().setUp() + + # crate_anon.common.constants.RUNNING_WITHOUT_CONFIG = True + + # If we have two SQLite in-memory database (with name = ":memory:"), + # they appear to be the same database. But equally, if you use a local + # temporary directory, nothing is created on disk; so presumably the + # Django test framework is intercepting everything? + self.tempdir = TemporaryDirectory() # will be deleted on destruction + self.settings( + DATABASES={ + DEFAULT_DB_ALIAS: { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(self.tempdir.name, "main.sqlite3"), + }, + self._research_django_db_name: { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join( + self.tempdir.name, "research.sqlite3" + ), + }, + }, + # DEBUG=True, + RESEARCH_DB_INFO=[ + { + RDIKeys.NAME: "research", + RDIKeys.DESCRIPTION: "Demo research database", + RDIKeys.DATABASE: "", + RDIKeys.SCHEMA: "research", + RDIKeys.PID_PSEUDO_FIELD: "pid", + RDIKeys.MPID_PSEUDO_FIELD: "mpid", + RDIKeys.TRID_FIELD: "trid", + RDIKeys.RID_FIELD: "brcid", + RDIKeys.RID_FAMILY: 1, + RDIKeys.MRID_TABLE: "patients", + RDIKeys.MRID_FIELD: "nhshash", + RDIKeys.PID_DESCRIPTION: "Patient ID", + RDIKeys.MPID_DESCRIPTION: "Master patient ID", + RDIKeys.RID_DESCRIPTION: "Research ID", + RDIKeys.MRID_DESCRIPTION: "Master research ID", + RDIKeys.TRID_DESCRIPTION: "Transient research ID", + RDIKeys.SECRET_LOOKUP_DB: "secret", + RDIKeys.DATE_FIELDS_BY_TABLE: {}, + RDIKeys.DEFAULT_DATE_FIELDS: [], + RDIKeys.UPDATE_DATE_FIELD: "_when_fetched_utc", + }, + ], + ) + self.mainconn = connections[DEFAULT_DB_ALIAS] + self.resconn = connections[self._research_django_db_name] + self.grammar = SqlGrammar() + with self.resconn.cursor() as cursor: + cursor.execute("CREATE TABLE t (a INT, b INT)") + cursor.execute("INSERT INTO t (a, b) VALUES (1, 101)") + cursor.execute("INSERT INTO t (a, b) VALUES (2, 102)") + cursor.execute("COMMIT") + + def tearDown(self) -> None: + with self.resconn.cursor() as cursor: + cursor.execute("DROP TABLE t") + # Otherwise, you can run one test, but if you run two, you get: + # + # django.db.transaction.TransactionManagementError: An error occurred + # in the current transaction. You can't execute queries until the end + # of the 'atomic' block. + # + # ... no - still the problem! + # Hack: combine the tests. + + def test_django_dummy_database_and_sqlite_schema_reader(self) -> None: + with self.resconn.cursor() as cursor: + cursor.execute("SELECT * FROM t") + results = dictfetchall(cursor) + self.assertEqual(len(results), 2) + self.assertEqual(results[0], dict(a=1, b=101)) + self.assertEqual(results[1], dict(a=2, b=102)) + + rdbi = ResearchDatabaseInfo(running_without_config=True) + srd = SingleResearchDatabase( + index=0, + grammar=self.grammar, + rdb_info=rdbi, + connection=self.resconn, + ) + col_info_list = srd.schema_infodictlist # will read the database + # Unfortunately it will read all the Django tables too (see above). + table_t_cols = [c for c in col_info_list if c["table_name"] == "t"] + self.assertTrue(len(table_t_cols) == 2) + row0 = table_t_cols[0] + self.assertEqual(row0["column_name"], "a") + self.assertEqual(row0["column_type"], "INT") + row1 = table_t_cols[1] + self.assertEqual(row1["column_name"], "b") + self.assertEqual(row1["column_type"], "INT") diff --git a/crate_anon/crateweb/research/views.py b/crate_anon/crateweb/research/views.py index ea661dca..3dc1d9b7 100644 --- a/crate_anon/crateweb/research/views.py +++ b/crate_anon/crateweb/research/views.py @@ -70,7 +70,6 @@ from crate_anon.common.constants import JSON_SEPARATORS_COMPACT -# from crate_anon.common.profiling import do_cprofile from crate_anon.common.sql import ( ColumnId, escape_sql_string_literal, @@ -145,7 +144,7 @@ ) from crate_anon.crateweb.research.research_db_info import ( PatientFieldPythonTypes, - research_database_info, + get_research_db_info, SingleResearchDatabase, ) from crate_anon.crateweb.research.sql_writer import ( @@ -280,6 +279,7 @@ def get_db_structure_json() -> str: Returns the research database structure in JSON format. """ log.debug("get_db_structure_json") + research_database_info = get_research_db_info() colinfolist = research_database_info.get_colinfolist() if not colinfolist: log.warning("get_db_structure_json(): colinfolist is empty") @@ -398,6 +398,7 @@ def query_build(request: HttpRequest) -> HttpResponse: # noinspection PyUnresolvedReferences profile = request.user.profile # type: UserProfile parse_error = "" + research_database_info = get_research_db_info() default_database = research_database_info.get_default_database_name() default_schema = research_database_info.get_default_schema_name() with_database = research_database_info.uses_database_level() @@ -688,6 +689,7 @@ def parse_privileged_sql(request: HttpRequest, sql: str) -> List[Any]: sql_components = sql.split() new_sql = "" i = 0 + research_database_info = get_research_db_info() while i < len(sql_components): split_component = sql_components[i].split(":") if len(split_component) == 2 and ( @@ -911,6 +913,7 @@ def query_edit_select(request: HttpRequest) -> HttpResponse: # Have to use plain sql for this (not coloured) in case it cuts it # off after an html start tag but before the end tag q.truncated_sql = sql[:50] + research_database_info = get_research_db_info() context = { "form": form, "queries": queries, @@ -1214,6 +1217,7 @@ def get_source_results( Returns: a :class:`NlpSourceResult` """ + research_database_info = get_research_db_info() try: dbname = research_database_info.nlp_sourcedb_map[srcdb] except KeyError: @@ -2222,6 +2226,7 @@ def pid_rid_lookup( Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() dbinfolist = research_database_info.dbs_with_secret_map n = len(dbinfolist) if n == 0: @@ -2269,6 +2274,7 @@ def pid_rid_lookup_with_db( # Union[Type[PidLookupForm], Type[RidLookupForm]] yet; we get # TypeError: descriptor '__subclasses__' of 'type' object needs an argument # ... see https://github.com/python/typing/issues/266 + research_database_info = get_research_db_info() try: dbinfo = research_database_info.get_dbinfo_by_name(dbname) except ValueError: @@ -2463,6 +2469,7 @@ def structure_table_long(request: HttpRequest) -> HttpResponse: Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() colinfolist = research_database_info.get_colinfolist() rowcount = len(colinfolist) context = { @@ -2486,6 +2493,7 @@ def structure_table_paginated(request: HttpRequest) -> HttpResponse: Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() colinfolist = research_database_info.get_colinfolist() rowcount = len(colinfolist) colinfolist = paginate(request, colinfolist) @@ -2510,6 +2518,7 @@ def get_structure_tree_html() -> str: Returns: str: HTML """ + research_database_info = get_research_db_info() table_to_colinfolist = research_database_info.get_colinfolist_by_tables() content = "" element_counter = HtmlElementCounter() @@ -2551,6 +2560,7 @@ def structure_tree(request: HttpRequest) -> HttpResponse: Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() context = { "content": get_structure_tree_html(), "default_database": research_database_info.get_default_database_name(), @@ -2570,6 +2580,7 @@ def structure_tsv(request: HttpRequest) -> HttpResponse: Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() return file_response( research_database_info.get_tsv(), content_type=ContentType.TSV, @@ -2589,6 +2600,7 @@ def structure_excel(request: HttpRequest) -> HttpResponse: Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() return file_response( research_database_info.get_excel(), content_type=ContentType.TSV, @@ -2721,6 +2733,7 @@ def textfinder_sql( raise ValueError( "Must supply either 'fragment' or 'drug_type' to 'textfinder_sql'" ) + research_database_info = get_research_db_info() grammar = research_database_info.grammar tables = research_database_info.tables_containing_field( patient_id_fieldname @@ -3096,6 +3109,7 @@ def sqlhelper_text_anywhere(request: HttpRequest) -> HttpResponse: Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() if research_database_info.single_research_db: dbname = research_database_info.first_dbinfo.name return HttpResponseRedirect( @@ -3132,6 +3146,7 @@ def sqlhelper_text_anywhere_with_db( Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() try: dbinfo = research_database_info.get_dbinfo_by_name(dbname) except ValueError: @@ -3164,6 +3179,7 @@ def sqlhelper_drug_type(request: HttpRequest) -> HttpResponse: Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() if research_database_info.single_research_db: dbname = research_database_info.first_dbinfo.name return HttpResponseRedirect( @@ -3197,6 +3213,7 @@ def sqlhelper_drug_type_with_db( Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() try: dbinfo = research_database_info.get_dbinfo_by_name(dbname) except ValueError: @@ -3229,6 +3246,7 @@ def all_text_from_pid(request: HttpRequest) -> HttpResponse: Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() dbinfolist = research_database_info.dbs_with_secret_map n = len(dbinfolist) if n == 0: @@ -3267,6 +3285,7 @@ def all_text_from_pid_with_db( Returns: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() try: dbinfo = research_database_info.get_dbinfo_by_name(dbname) except ValueError: @@ -3304,6 +3323,7 @@ def pe_build(request: HttpRequest) -> HttpResponse: a :class:`django.http.response.HttpResponse` """ + research_database_info = get_research_db_info() # noinspection PyUnresolvedReferences profile = request.user.profile # type: UserProfile default_database = research_database_info.get_default_database_name() @@ -3562,6 +3582,7 @@ def pe_results(request: HttpRequest, pe_id: str) -> HttpResponse: a :class:`django.http.response.HttpResponse` """ pe = get_object_or_404(PatientExplorer, id=pe_id) # type: PatientExplorer + research_database_info = get_research_db_info() grammar = research_database_info.grammar # noinspection PyUnresolvedReferences profile = request.user.profile # type: UserProfile @@ -3948,6 +3969,7 @@ def pe_monster_results(request: HttpRequest, pe_id: str) -> HttpResponse: """ pe = get_object_or_404(PatientExplorer, id=pe_id) # type: PatientExplorer + research_database_info = get_research_db_info() grammar = research_database_info.grammar # noinspection PyUnresolvedReferences profile = request.user.profile # type: UserProfile @@ -4026,6 +4048,7 @@ def pe_table_browser(request: HttpRequest, pe_id: str) -> HttpResponse: """ pe = get_object_or_404(PatientExplorer, id=pe_id) # type: PatientExplorer + research_database_info = get_research_db_info() tables = research_database_info.get_tables() with_database = research_database_info.uses_database_level() try: @@ -4071,6 +4094,7 @@ def pe_one_table( """ pe = get_object_or_404(PatientExplorer, id=pe_id) # type: PatientExplorer + research_database_info = get_research_db_info() table_id = TableId(db=db, schema=schema, table=table) grammar = research_database_info.grammar highlights = Highlight.get_active_highlights(request) diff --git a/devnotes/2025_01_sqlalchemy2_databricks/pipeline_test.sh b/devnotes/2025_01_sqlalchemy2_databricks/pipeline_test.sh index fcca045b..04cc8b52 100755 --- a/devnotes/2025_01_sqlalchemy2_databricks/pipeline_test.sh +++ b/devnotes/2025_01_sqlalchemy2_databricks/pipeline_test.sh @@ -1,5 +1,7 @@ #!/bin/bash +# Run a battery of CRATE tests locally. + set -ex if [ -z "$TMP_CRATE_DEMO_DATABASE_URL" ]; then @@ -7,6 +9,10 @@ if [ -z "$TMP_CRATE_DEMO_DATABASE_URL" ]; then exit 1 fi +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +TEST_ROOT_DIR="${SCRIPT_DIR}/../.." +cd "${TEST_ROOT_DIR}" + crate_make_demo_database "${TMP_CRATE_DEMO_DATABASE_URL}" crate_anon_draft_dd crate_anonymise --full