diff --git a/CHANGELOG.md b/CHANGELOG.md index d4376ccd1..f141465d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ * [Feature] Upgrades SQLAlchemy version to 2 +* [Fix] Improves `%sqlcmd test`: reworked comparison logic and added a user guide (#275) + ## 0.7.0 (2023-04-05) JupySQL is now available via `conda install jupysql -c conda-forge`. Thanks, [@sterlinm](https://github.com/sterlinm)! diff --git a/doc/_toc.yml b/doc/_toc.yml index af4ebfd8b..3928962e8 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -11,6 +11,7 @@ parts: - file: plot - file: compose - file: user-guide/tables-columns + - file: user-guide/testing-columns - file: plot-legacy - file: user-guide/template - file: user-guide/interactive diff --git a/doc/integrations/mindsdb.ipynb b/doc/integrations/mindsdb.ipynb index d30728de4..d7aaf7d6a 100644 --- a/doc/integrations/mindsdb.ipynb +++ b/doc/integrations/mindsdb.ipynb @@ -704,7 +704,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.5" + "version": "3.10.9" }, "vscode": { "interpreter": { diff --git a/doc/integrations/mssql.ipynb b/doc/integrations/mssql.ipynb index 34ad6b5cf..a1bada570 100644 --- a/doc/integrations/mssql.ipynb +++ b/doc/integrations/mssql.ipynb @@ -1117,7 +1117,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.10.9" }, "myst": { "html_meta": { diff --git a/doc/integrations/postgres-connect.ipynb b/doc/integrations/postgres-connect.ipynb index 6e5ddfb61..33794115a 100644 --- a/doc/integrations/postgres-connect.ipynb +++ b/doc/integrations/postgres-connect.ipynb @@ -1058,7 +1058,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.10.9" }, "myst": { "html_meta": { diff --git a/doc/user-guide/tables-columns.md b/doc/user-guide/tables-columns.md index 2f48d553f..adf778720 100644 --- a/doc/user-guide/tables-columns.md +++ 
b/doc/user-guide/tables-columns.md @@ -103,17 +103,3 @@ Get the columns for the table in the newly created schema: ```{code-cell} ipython3 %sqlcmd columns --table numbers --schema some_schema ``` - -## Run Tests on Column - -Use `%sqlcmd test` to run tests on your dataset. - -For example, to see if all the values in the column birth_year are greater than 100: - -```{code-cell} ipython3 -%sqlcmd test --table people --column birth_year --greater 100 -``` - -Four different comparator commands exist: `greater`, `greater-or-equal`, `less-than`, `less-than-or-equal`, and `no-nulls`. - -Command will return True if all tests pass, otherwise an error with sample breaking cases will be printed out. diff --git a/doc/user-guide/testing-columns.md b/doc/user-guide/testing-columns.md new file mode 100644 index 000000000..07d48f4cc --- /dev/null +++ b/doc/user-guide/testing-columns.md @@ -0,0 +1,82 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.14.4 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +myst: + html_meta: + description lang=en: Test columns from your database in Jupyter via JupySQL + keywords: jupyter, sql, jupysql + property=og:locale: en_US +--- + + +# Testing with sqlcmd + +```{note} +This example uses `SQLite` but the same commands work for other databases. +``` + +```{code-cell} ipython3 +%load_ext sql +%sql sqlite:// +``` + +Let's create a sample table: + +```{code-cell} ipython3 +:tags: [hide-output] +%%sql sqlite:// +CREATE TABLE writer (first_name, last_name, year_of_death); +INSERT INTO writer VALUES ('William', 'Shakespeare', 1616); +INSERT INTO writer VALUES ('Bertold', 'Brecht', 1956); +``` + + +## Run Tests on Column + +Use `%sqlcmd test` to run quantitative tests on your dataset. 
+ +For example, to see if all the values in the column `year_of_death` are less than 2000, we can use: + +```{code-cell} ipython3 +%sqlcmd test --table writer --column year_of_death --less-than 2000 +``` + +Because both William Shakespeare and Bertold Brecht died before the year 2000, this command will return True. + +However, if we were to run: + +```{code-cell} ipython3 +:tags: [raises-exception] +%sqlcmd test --table writer --column year_of_death --greater 1700 +``` + +We see that a value that failed our test was William Shakespeare, as he died in 1616. + +We can also pass several comparator arguments to test: + +```{code-cell} ipython3 +:tags: [raises-exception] +%sqlcmd test --table writer --column year_of_death --greater-or-equal 1616 --less-than-or-equal 1956 +``` + +Here, because Shakespeare died in 1616 and Brecht in 1956, our test passes. + +However, if we search for a window between 1800 and 1900: + +```{code-cell} ipython3 +:tags: [raises-exception] +%sqlcmd test --table writer --column year_of_death --greater 1800 --less-than 1900 +``` + +The test fails, returning both Shakespeare and Brecht. + +Currently, 5 different comparator arguments are supported: `greater`, `greater-or-equal`, `less-than`, `less-than-or-equal`, and `no-nulls`. 
+ diff --git a/src/sql/magic_cmd.py b/src/sql/magic_cmd.py index a485b0051..9cab00632 100644 --- a/src/sql/magic_cmd.py +++ b/src/sql/magic_cmd.py @@ -8,6 +8,8 @@ from sqlglot import select, condition from sqlalchemy import text +from prettytable import PrettyTable + try: from traitlets.config.configurable import Configurable except ImportError: @@ -135,6 +137,20 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): ) args = parser.parse_args(others) + + COMPARATOR_ARGS = [ + args.greater, + args.greater_or_equal, + args.less_than, + args.less_than_or_equal, + ] + + if args.table and not any(COMPARATOR_ARGS): + raise UsageError("Please use a valid comparator.") + + if args.table and any(COMPARATOR_ARGS) and not args.column: + raise UsageError("Please pass a column to test.") + if args.greater and args.greater_or_equal: return ValueError( "You cannot use both greater and greater " @@ -149,11 +165,18 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): conn = sql.connection.Connection.current.session result_dict = run_each_individually(args, conn) - if len(result_dict.keys()): - print( - "Test failed. Returned are samples of the failures from your data:" + if any(len(rows) > 1 for rows in list(result_dict.values())): + for comparator, rows in result_dict.items(): + if len(rows) > 1: + print(f"\n{comparator}:\n") + _pretty = PrettyTable() + _pretty.field_names = rows[0] + for row in rows[1:]: + _pretty.add_row(row) + print(_pretty) + raise UsageError( + "The above values do not not match your test requirements." 
) - return result_dict else: return True @@ -182,45 +205,65 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): return report +def return_test_results(args, conn, query): + try: + columns = [] + column_data = conn.execute(text(query)).cursor.description + res = conn.execute(text(query)).fetchall() + for column in column_data: + columns.append(column[0]) + res = [columns, *res] + return res + except Exception as e: + if "column" in str(e): + raise UsageError(f"Referenced column '{args.column}' not found!") + + def run_each_individually(args, conn): base_query = select("*").from_(args.table) + storage = {} if args.greater: - where = condition(args.column + ">" + args.greater) + where = condition(args.column + "<=" + args.greater) current_query = base_query.where(where).sql() - res = conn.execute(text(current_query)).fetchone() + res = return_test_results(args, conn, query=current_query) if res is not None: storage["greater"] = res if args.greater_or_equal: - where = condition(args.column + ">=" + args.greater_or_equal) + where = condition(args.column + "<" + args.greater_or_equal) current_query = base_query.where(where).sql() - res = conn.execute(text(current_query)).fetchone() + res = return_test_results(args, conn, query=current_query) + if res is not None: storage["greater_or_equal"] = res + if args.less_than_or_equal: - where = condition(args.column + "<=" + args.less_than_or_equal) + where = condition(args.column + ">" + args.less_than_or_equal) current_query = base_query.where(where).sql() - res = conn.execute(text(current_query)).fetchone() + res = return_test_results(args, conn, query=current_query) + if res is not None: storage["less_than_or_equal"] = res if args.less_than: - where = condition(args.column + "<" + args.less_than) + where = condition(args.column + ">=" + args.less_than) current_query = base_query.where(where).sql() - res = conn.execute(text(current_query)).fetchone() + res = return_test_results(args, conn, 
query=current_query) + if res is not None: storage["less_than"] = res if args.no_nulls: where = condition("{} is NULL".format(args.column)) current_query = base_query.where(where).sql() - res = conn.execute(text(current_query)).fetchone() + res = return_test_results(args, conn, query=current_query) + if res is not None: storage["null"] = res diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index d911e26cf..e19816acc 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -329,17 +329,28 @@ def test_sqlplot_boxplot(ip_with_dynamic_db, cell, request, test_table_name_dict # ("ip_with_Snowflake"), ], ) -def test_sql_cmd_magic_uno(ip_with_dynamic_db, request, test_table_name_dict): +def test_sql_cmd_magic_uno(ip_with_dynamic_db, request, capsys): ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db) - result = ip_with_dynamic_db.run_cell( - f"%sqlcmd test --table {test_table_name_dict['numbers']}\ - --column numbers_elements --less-than 5 --greater 1" - ).result + ip_with_dynamic_db.run_cell( + """ + %%sql sqlite:// + CREATE TABLE test_numbers (value); + INSERT INTO test_numbers VALUES (0); + INSERT INTO test_numbers VALUES (4); + INSERT INTO test_numbers VALUES (5); + INSERT INTO test_numbers VALUES (6); + """ + ) + + ip_with_dynamic_db.run_cell( + "%sqlcmd test --table test_numbers --column value" " --less-than 5 --greater 1" + ) + + _out = capsys.readouterr() - assert len(result) == 2 - assert "less_than" in result.keys() - assert "greater" in result.keys() + assert "less_than" in _out.out + assert "greater" in _out.out @pytest.mark.parametrize( @@ -359,18 +370,28 @@ def test_sql_cmd_magic_uno(ip_with_dynamic_db, request, test_table_name_dict): # ), ], ) -def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, test_table_name_dict): +def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): ip_with_dynamic_db = 
request.getfixturevalue(ip_with_dynamic_db) - result = ip_with_dynamic_db.run_cell( - f"%sqlcmd test --table {test_table_name_dict['numbers']}\ - --column numbers_elements" - " --greater-or-equal 3" - ).result + ip_with_dynamic_db.run_cell( + """ + %%sql sqlite:// + CREATE TABLE test_numbers (value); + INSERT INTO test_numbers VALUES (0); + INSERT INTO test_numbers VALUES (4); + INSERT INTO test_numbers VALUES (5); + INSERT INTO test_numbers VALUES (6); + """ + ) + + ip_with_dynamic_db.run_cell( + "%sqlcmd test --table test_numbers --column value --greater-or-equal 3" + ) + + _out = capsys.readouterr() - assert len(result) == 1 - assert "greater_or_equal" in result.keys() - assert list(result["greater_or_equal"]) == [2, 3] + assert "greater_or_equal" in _out.out + assert "0" in _out.out @pytest.mark.parametrize( diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index e26515105..ef97fe8b7 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -209,3 +209,37 @@ def test_table_profile_store(ip, tmp_empty): report = Path("test_report.html") assert report.is_file() + + +@pytest.mark.parametrize( + "cell, error_type, error_message", + [ + ["%sqlcmd test -t test_numbers", UsageError, "Please use a valid comparator."], + [ + "%sqlcmd test --t test_numbers --greater 12", + UsageError, + "Please pass a column to test.", + ], + [ + "%sqlcmd test --table test_numbers --column something --greater 100", + UsageError, + "Referenced column 'something' not found!", + ], + ], +) +def test_test_error(ip, cell, error_type, error_message): + ip.run_cell( + """ + %%sql sqlite:// + CREATE TABLE test_numbers (value); + INSERT INTO test_numbers VALUES (14); + INSERT INTO test_numbers VALUES (13); + INSERT INTO test_numbers VALUES (12); + INSERT INTO test_numbers VALUES (11); + """ + ) + + out = ip.run_cell(cell) + + assert isinstance(out.error_in_exec, error_type) + assert str(out.error_in_exec) == error_message