From ccc906d401a82e65d1d635d58326e60efda31776 Mon Sep 17 00:00:00 2001 From: DaveOkpare Date: Sat, 18 Mar 2023 11:19:12 +0100 Subject: [PATCH] Refactor run function and added unit test --- src/sql/run.py | 111 ++++++++++++++++++++------------- src/tests/test_run.py | 142 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 44 deletions(-) create mode 100644 src/tests/test_run.py diff --git a/src/sql/run.py b/src/sql/run.py index 24dd376cf..1c978e302 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -386,53 +386,76 @@ def _commit(conn, config, manual_commit): print("The database does not support the COMMIT command") -def run(conn, sql, config): - if sql.strip(): - for statement in sqlparse.split(sql): - first_word = sql.strip().split()[0].lower() - manual_commit = False - if first_word == "begin": - raise Exception("ipython_sql does not support transactions") - if first_word.startswith("\\") and ( - "postgres" in str(conn.dialect) or "redshift" in str(conn.dialect) - ): - if not PGSpecial: - raise ImportError("pgspecial not installed") - pgspecial = PGSpecial() - _, cur, headers, _ = pgspecial.execute( - conn.session.connection.cursor(), statement - )[0] - result = FakeResultProxy(cur, headers) - else: - txt = sqlalchemy.sql.text(statement) - if config.autocommit: - try: - conn.session.execution_options(isolation_level="AUTOCOMMIT") - except Exception as e: - logging.debug( - f"The database driver doesn't support such " - f"AUTOCOMMIT execution option" - f"\nPerhaps you can try running a manual COMMIT command" - f"\nMessage from the database driver\n\t" - f"Exception: {e}\n", # noqa: F841 - ) - manual_commit = True - result = conn.session.execute(txt) - _commit(conn=conn, config=config, manual_commit=manual_commit) - if result and config.feedback: - print(interpret_rowcount(result.rowcount)) - - resultset = ResultSet(result, config) - if config.autopandas: - return resultset.DataFrame() - elif config.autopolars: - return resultset.PolarsDataFrame() - else: - return resultset - # returning only last result, intentionally +def is_postgres_or_redshift(dialect): + """Checks if dialect is postgres or redshift""" + return "postgres" in str(dialect) or "redshift" in str(dialect) + + +def handle_postgres_special(conn, statement): + """Execute a PostgreSQL special statement using PGSpecial module.""" + if not PGSpecial: + raise ImportError("pgspecial not installed") + pgspecial = PGSpecial() + _, cur, headers, _ = pgspecial.execute(conn.session.connection.cursor(), statement)[ + 0 + ] + return FakeResultProxy(cur, headers) + + +def set_autocommit(conn, config): + """Sets the autocommit setting for a database connection.""" + if config.autocommit: + try: + conn.session.execution_options(isolation_level="AUTOCOMMIT") + except Exception as e: + logging.debug( + f"The database driver doesn't support such " + f"AUTOCOMMIT execution option" + f"\nPerhaps you can try running a manual COMMIT command" + f"\nMessage from the database driver\n\t" + f"Exception: {e}\n", # noqa: F841 + ) + return True + return False + + +def select_df_type(resultset, config): + """ + Converts the input resultset to either a Pandas DataFrame + or Polars DataFrame based on the config settings. + """ + if config.autopandas: + return resultset.DataFrame() + elif config.autopolars: + return resultset.PolarsDataFrame() else: + return resultset + # returning only last result, intentionally + + +def run(conn, sql, config): + if not sql.strip(): + # returning only when sql is empty string return "Connected: %s" % conn.name + for statement in sqlparse.split(sql): + first_word = sql.strip().split()[0].lower() + manual_commit = False + if first_word == "begin": + raise ValueError("ipython_sql does not support transactions") + if first_word.startswith("\\") and is_postgres_or_redshift(conn.dialect): + result = handle_postgres_special(conn, statement) + else: + txt = sqlalchemy.sql.text(statement) + manual_commit = set_autocommit(conn, config) + result = conn.session.execute(txt) + _commit(conn=conn, config=config, manual_commit=manual_commit) + if result and config.feedback: + print(interpret_rowcount(result.rowcount)) + + resultset = ResultSet(result, config) + return select_df_type(resultset, config) + class PrettyTable(prettytable.PrettyTable): def __init__(self, *args, **kwargs): diff --git a/src/tests/test_run.py b/src/tests/test_run.py new file mode 100644 index 000000000..f43201c43 --- /dev/null +++ b/src/tests/test_run.py @@ -0,0 +1,142 @@ +import logging +from unittest.mock import Mock + +import pandas +import polars +import pytest + +from sql.connection import Connection +from sql.run import ( + run, + handle_postgres_special, + is_postgres_or_redshift, + select_df_type, + set_autocommit, + interpret_rowcount, +) + + +@pytest.fixture +def mock_conns(): + Connection.name = str() + Connection.dialect = "postgres" + return Connection + + +@pytest.fixture +def mock_config(): + class Config: + autopandas = None + autopolars = None + autocommit = True + feedback = True + + return Config + + +@pytest.fixture +def config_pandas(mock_config): + mock_config.autopandas = True + mock_config.autopolars = False + + return mock_config + + +@pytest.fixture +def config_polars(mock_config): + mock_config.autopandas = False + mock_config.autopolars = True + + return mock_config + + +@pytest.fixture +def mock_resultset(): + class ResultSet: + def __init__(self, *args, **kwargs): + ... + + @classmethod + def DataFrame(cls): + return pandas.DataFrame() + + @classmethod + def PolarsDataFrame(cls): + return polars.DataFrame() + + return ResultSet + + +@pytest.mark.parametrize( + "dialect", + [ + "postgres", + "redshift", + ], +) +def test_is_postgres_or_redshift(dialect): + assert is_postgres_or_redshift(dialect) is True + + +def test_handle_postgres_special(mock_conns): + with pytest.raises(ImportError): + handle_postgres_special(mock_conns, "\\") + + +def test_set_autocommit(mock_conns, mock_config, caplog): + caplog.set_level(logging.DEBUG) + output = set_autocommit(mock_conns, mock_config) + assert "The database driver doesn't support such " in caplog.records[0].msg + assert output is True + + +def test_select_df_type_is_pandas(monkeypatch, config_pandas, mock_resultset): + monkeypatch.setattr("sql.run.select_df_type", mock_resultset.DataFrame()) + output = select_df_type(mock_resultset, config_pandas) + assert isinstance(output, pandas.DataFrame) + + +def test_select_df_type_is_polars(monkeypatch, config_polars, mock_resultset): + monkeypatch.setattr("sql.run.select_df_type", mock_resultset.PolarsDataFrame()) + output = select_df_type(mock_resultset, config_polars) + assert isinstance(output, polars.DataFrame) + + +def test_sql_starts_with_begin(mock_conns, mock_config): + with pytest.raises(ValueError, match="does not support transactions"): + run(mock_conns, "BEGIN", mock_config) + + +def test_sql_is_empty(mock_conns, mock_config): + assert run(mock_conns, " ", mock_config) == "Connected: %s" % mock_conns.name + + +def test_run(monkeypatch, mock_conns, mock_resultset, config_pandas): + monkeypatch.setattr("sql.run.handle_postgres_special", Mock()) + monkeypatch.setattr("sql.run._commit", Mock()) + monkeypatch.setattr("sql.run.interpret_rowcount", Mock()) + monkeypatch.setattr("sql.run.ResultSet", mock_resultset) + + output = run(mock_conns, "\\", config_pandas) + assert isinstance(output, type(mock_resultset.DataFrame())) + + +def test_interpret_rowcount(): + assert interpret_rowcount(-1) == "Done." + assert interpret_rowcount(1) == "%d rows affected." % 1 + + +def test__commit_is_called( + monkeypatch, + mock_conns, + mock_config, +): + mock__commit = Mock() + monkeypatch.setattr("sql.run._commit", mock__commit) + monkeypatch.setattr("sql.run.handle_postgres_special", Mock()) + monkeypatch.setattr("sql.run.interpret_rowcount", Mock()) + monkeypatch.setattr("sql.run.ResultSet", Mock()) + + run(mock_conns, "\\", mock_config) + + mock__commit.assert_called()