diff --git a/.github/workflows/formatter.yml b/.github/workflows/formatter.yml index c1860fdf..384c34c3 100644 --- a/.github/workflows/formatter.yml +++ b/.github/workflows/formatter.yml @@ -21,13 +21,25 @@ jobs: uses: actions/checkout@v3 if: github.event_name == 'workflow_dispatch' - - name: Check files using the ruff formatter + # This is used for forked PRs as write permissions are required to format files + - name: Run and commit changes with `ruff format .` locally on your forked branch to fix errors if they appear + if: ${{ github.event.pull_request.head.repo.fork == true }} + uses: chartboost/ruff-action@v1 + id: ruff_formatter_suggestions + with: + args: format --diff + + # This only runs if the PR is NOT from a forked repo + - name: Format files using ruff + if: ${{ github.event.pull_request.head.repo.fork == false }} uses: chartboost/ruff-action@v1 id: ruff_formatter with: args: format + # This only runs if the PR is NOT from a forked repo - name: Auto commit ruff formatting + if: ${{ github.event.pull_request.head.repo.fork == false }} uses: stefanzweifel/git-auto-commit-action@v5 with: - commit_message: 'style fixes by ruff' \ No newline at end of file + commit_message: 'style fixes by ruff' diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 647388f2..2b1e810f 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, Type +from typing import Any, ClassVar, Dict, Type, Union import attrs @@ -20,6 +20,7 @@ import_helper, ConnectError, BaseDialect, + ThreadLocalInterpreter, ) from data_diff.databases.base import ( MD5_HEXDIGITS, @@ -148,3 +149,11 @@ def create_connection(self): elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: raise ConnectError("Database does not exist") from e raise ConnectError(*e.args) from e + + def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]): + "This method runs in a worker thread" + if self._init_error: + raise self._init_error + if not self.thread_local.conn.is_connected(): + self.thread_local.conn.ping(reconnect=True, attempts=3, delay=5) + return self._query_conn(self.thread_local.conn, sql_code) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index d29fa0eb..e9a38cde 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,5 +1,5 @@ from typing import Any, ClassVar, Dict, List, Type - +from urllib.parse import unquote import attrs from data_diff.abcs.database_types import ( @@ -168,6 +168,7 @@ def create_connection(self): pg = import_postgresql() try: + self._args["password"] = unquote(self._args["password"]) self._conn = pg.connect( **self._args, keepalives=1, keepalives_idle=5, keepalives_interval=2, keepalives_count=2 ) diff --git a/data_diff_demo b/data_diff_demo new file mode 160000 index 00000000..d0784e8d --- /dev/null +++ b/data_diff_demo @@ -0,0 +1 @@ +Subproject commit d0784e8de9fc7958f91a599fa454be4f8b09c60d diff --git a/datafold-demo-sung b/datafold-demo-sung new file mode 160000 index 00000000..6ebfb06d --- /dev/null +++ b/datafold-demo-sung @@ -0,0 +1 @@ +Subproject commit 6ebfb06d1e0937309384cdb4955e6dbd23387256 diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index b5e9fa10..ed1baecf 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,9 +1,11 @@ import unittest +from urllib.parse import quote from data_diff.queries.api import table, commit from data_diff import TableSegment, HashDiffer from data_diff import databases as db -from tests.common import get_conn, random_table_suffix +from tests.common import get_conn, random_table_suffix, connect +from data_diff import connect_to_table class TestUUID(unittest.TestCase): @@ -113,3 +115,41 @@ def test_100_fields(self): id_ = diff[0][1][0] result = (id_,) + tuple("1" for x in range(100)) self.assertEqual(diff, [("-", result)]) + + +class TestSpecialCharacterPassword(unittest.TestCase): + def setUp(self) -> None: + self.connection = get_conn(db.PostgreSQL) + + table_suffix = random_table_suffix() + + self.table_name = f"table{table_suffix}" + self.table = table(self.table_name) + + def test_special_char_password(self): + password = "passw!!!@rd" + # Setup user with special character '@' in password + self.connection.query("DROP USER IF EXISTS test;", None) + self.connection.query(f"CREATE USER test WITH PASSWORD '{password}';", None) + + password_quoted = quote(password) + db_config = { + "driver": "postgresql", + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "test", + "password": password_quoted, + } + + # verify pythonic connection method + connect_to_table( + db_config, + self.table_name, + ) + + # verify connection method with URL string unquoted after it's verified + db_url = f"postgresql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['dbname']}" + + connection_verified = connect(db_url) + assert connection_verified._args.get("password") == password