Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fix: add disallowed query params for engines specs #23217

Merged
merged 6 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# This set will give the keywords for data limit statements
# to consider for the engines with TOP SQL parsing
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()

force_column_alias_quotes = False
arraysize = 0
Expand Down Expand Up @@ -1724,6 +1726,19 @@ def get_public_information(cls) -> Dict[str, Any]:
"disable_ssh_tunneling": cls.disable_ssh_tunneling,
}

@classmethod
def validate_database_uri(cls, sqlalchemy_uri: URL) -> None:
"""
Validates a database SQLAlchemy URI per engine spec.
Use this to implement a final validation for unwanted connection configuration

:param sqlalchemy_uri:
"""
if existing_disallowed := cls.disallow_uri_query_params.intersection(
sqlalchemy_uri.query
):
raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")


# schema for adding a database by providing parameters instead of the
# full SQLAlchemy URI
Expand Down
1 change: 1 addition & 0 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
{},
),
}
disallow_uri_query_params = {"local_infile"}

@classmethod
def convert_dttm(
Expand Down
2 changes: 2 additions & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ def _get_sqla_engine(
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)

sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
Expand Down
20 changes: 20 additions & 0 deletions tests/unit_tests/db_engine_specs/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
TINYINT,
TINYTEXT,
)
from sqlalchemy.engine.url import make_url

from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
Expand Down Expand Up @@ -99,6 +100,25 @@ def test_convert_dttm(
assert_convert_dttm(spec, target_type, expected_result, dttm)


@pytest.mark.parametrize(
"sqlalchemy_uri,error",
[
("mysql://user:password@host/db1?local_infile=1", True),
("mysql://user:password@host/db1?local_infile=0", True),
("mysql://user:password@host/db1", False),
],
)
def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec

url = make_url(sqlalchemy_uri)
if error:
with pytest.raises(ValueError):
MySQLEngineSpec.validate_database_uri(url)
return
MySQLEngineSpec.validate_database_uri(url)


@patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
Expand Down