diff --git a/superset/app.py b/superset/app.py index 5280922a1d1b0..e073f45b02d07 100644 --- a/superset/app.py +++ b/superset/app.py @@ -149,8 +149,8 @@ def init_views(self) -> None: AlertLogModelView, AlertModelView, AlertObservationModelView, - ValidatorInlineView, SQLObserverInlineView, + ValidatorInlineView, ) from superset.views.annotations import ( AnnotationLayerModelView, diff --git a/superset/databases/api.py b/superset/databases/api.py index c3355f538f98f..5a65b84a33ed6 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -20,8 +20,15 @@ from flask import g, request, Response from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface +from flask_babel import gettext as _ from marshmallow import ValidationError -from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError +from sqlalchemy.engine.url import make_url +from sqlalchemy.exc import ( + NoSuchModuleError, + NoSuchTableError, + OperationalError, + SQLAlchemyError, +) from superset import event_logger from superset.constants import RouteMethod @@ -33,8 +40,10 @@ DatabaseDeleteFailedError, DatabaseInvalidError, DatabaseNotFoundError, + DatabaseSecurityUnsafeError, DatabaseUpdateFailedError, ) +from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.commands.update import UpdateDatabaseCommand from superset.databases.dao import DatabaseDAO from superset.databases.decorators import check_datasource_access @@ -44,6 +53,7 @@ DatabasePostSchema, DatabasePutSchema, DatabaseRelatedObjectsResponse, + DatabaseTestConnectionSchema, SchemasResponseSchema, SelectStarResponseSchema, TableMetadataResponseSchema, @@ -65,6 +75,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "table_metadata", "select_star", "schemas", + "test_connection", "related_objects", } class_permission_name = "DatabaseView" @@ -343,7 +354,7 @@ def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ @rison(database_schemas_query_schema) @statsd_metrics def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse: - """ Get all schemas from a database + """Get all schemas from a database --- get: description: Get all schemas from a database @@ -400,7 +411,7 @@ def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse: def table_metadata( self, database: Database, table_name: str, schema_name: str ) -> FlaskResponse: - """ Table schema info + """Table schema info --- get: description: Get database table metadata @@ -457,7 +468,7 @@ def table_metadata( def select_star( self, database: Database, table_name: str, schema_name: Optional[str] = None ) -> FlaskResponse: - """ Table schema info + """Table schema info --- get: description: Get database select star for table @@ -506,6 +517,86 @@ def select_star( self.incr_stats("success", self.select_star.__name__) return self.response(200, result=result) + @expose("/test_connection", methods=["POST"]) + @protect() + @safe + @event_logger.log_this + @statsd_metrics + def test_connection( # pylint: disable=too-many-return-statements + self, + ) -> FlaskResponse: + """Tests a database connection + --- + post: + description: >- + Tests a database connection + requestBody: + description: Database schema + required: true + content: + application/json: + schema: + type: object + properties: + encrypted_extra: + type: object + extras: + type: object + name: + type: string + server_cert: + type: string + responses: + 200: + description: Database Test Connection + content: + application/json: + schema: + type: object + properties: + message: + type: string + 400: + $ref: '#/components/responses/400' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + if not request.is_json: + return self.response_400(message="Request is not JSON") + try: + item = DatabaseTestConnectionSchema().load(request.json) + # This validates custom Schema with custom validations + except ValidationError as error: + return self.response_400(message=error.messages) + try: + TestConnectionDatabaseCommand(g.user, item).run() + return self.response(200, message="OK") + except (NoSuchModuleError, ModuleNotFoundError): + logger.info("Invalid driver") + driver_name = make_url(item.get("sqlalchemy_uri")).drivername + return self.response( + 400, + message=_(f"Could not load database driver: {driver_name}"), + driver_name=driver_name, + ) + except DatabaseSecurityUnsafeError as ex: + return self.response_422(message=ex) + except OperationalError: + logger.warning("Connection failed") + return self.response( + 500, + message=_("Connection failed, please check your connection settings"), + ) + except Exception as ex: # pylint: disable=broad-except + logger.error("Unexpected error %s", type(ex).__name__) + return self.response_400( + message=_( + "Unexpected error occurred, please check your logs for details" + ) + ) + @expose("//related_objects/", methods=["GET"]) @protect() @safe diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py index 66a3245b17cb0..51d1660ca73bf 100644 --- a/superset/databases/commands/exceptions.py +++ b/superset/databases/commands/exceptions.py @@ -24,6 +24,7 @@ DeleteFailedError, UpdateFailedError, ) +from superset.security.analytics_db_safety import DBSecurityException class DatabaseInvalidError(CommandInvalidError): @@ -109,3 +110,7 @@ class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError): class DatabaseDeleteFailedError(DeleteFailedError): message = _("Database could not be deleted.") + + +class DatabaseSecurityUnsafeError(DBSecurityException): + message = _("Stopped an unsafe database connection") diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py new file mode 100644 index 0000000000000..3bcd5b09237cc --- /dev/null +++ b/superset/databases/commands/test_connection.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from contextlib import closing +from typing import Any, Dict, Optional + +import simplejson as json +from flask_appbuilder.security.sqla.models import User +from sqlalchemy import select + +from superset.commands.base import BaseCommand +from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError +from superset.databases.dao import DatabaseDAO +from superset.models.core import Database +from superset.security.analytics_db_safety import DBSecurityException + +logger = logging.getLogger(__name__) + + +class TestConnectionDatabaseCommand(BaseCommand): + def __init__(self, user: User, data: Dict[str, Any]): + self._actor = user + self._properties = data.copy() + self._model: Optional[Database] = None + + def run(self) -> None: + self.validate() + try: + uri = self._properties.get("sqlalchemy_uri", "") + if self._model and uri == self._model.safe_sqlalchemy_uri(): + uri = self._model.sqlalchemy_uri_decrypted + + database = DatabaseDAO.build_db_for_connection_test( + server_cert=self._properties.get("server_cert", ""), + extra=json.dumps(self._properties.get("extra", {})), + impersonate_user=self._properties.get("impersonate_user", False), + encrypted_extra=json.dumps(self._properties.get("encrypted_extra", {})), + ) + if database is not None: + database.set_sqlalchemy_uri(uri) + database.db_engine_spec.mutate_db_for_connection_test(database) + username = self._actor.username if self._actor is not None else None + engine = database.get_sqla_engine(user_name=username) + with closing(engine.connect()) as conn: + conn.scalar(select([1])) + except DBSecurityException as ex: + logger.warning(ex) + raise DatabaseSecurityUnsafeError() + + def validate(self) -> None: + database_name = self._properties.get("database_name") + if database_name is not None: + self._model = DatabaseDAO.get_database_by_name(database_name) diff --git a/superset/databases/dao.py b/superset/databases/dao.py index 804ac129e4b06..2e89ad0735dfc 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any, Dict, Optional from superset.dao.base import BaseDAO from superset.databases.filters import DatabaseFilter @@ -45,6 +45,25 @@ def validate_update_uniqueness(database_id: int, database_name: str) -> bool: ) return not db.session.query(database_query.exists()).scalar() + @staticmethod + def get_database_by_name(database_name: str) -> Optional[Database]: + return ( + db.session.query(Database) + .filter(Database.database_name == database_name) + .one_or_none() + ) + + @staticmethod + def build_db_for_connection_test( + server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str + ) -> Optional[Database]: + return Database( + server_cert=server_cert, + extra=extra, + impersonate_user=impersonate_user, + encrypted_extra=encrypted_extra, + ) + @classmethod def get_related_objects(cls, database_id: int) -> Dict[str, Any]: datasets = cls.find_by_id(database_id).tables diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 2d6779df0512e..859eebb9290d6 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -17,6 +17,7 @@ import inspect import json +from flask import current_app from flask_babel import lazy_gettext as _ from marshmallow import fields, Schema from marshmallow.validate import Length, ValidationError @@ -24,7 +25,6 @@ from sqlalchemy.engine.url import make_url from sqlalchemy.exc import ArgumentError -from superset import app from superset.exceptions import CertificateException from superset.utils.core import markdown, parse_ssl_cert @@ -142,7 +142,7 @@ def sqlalchemy_uri_validator(value: str) -> str: ) ] ) - if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value: + if current_app.config.get("PREVENT_UNSAFE_DB_CONNECTIONS", True) and value: if value.startswith("sqlite"): raise ValidationError( [ @@ -291,6 +291,25 @@ class DatabasePutSchema(Schema): ) +class DatabaseTestConnectionSchema(Schema): + database_name = fields.String( + description=database_name_description, allow_none=True, validate=Length(1, 250), + ) + impersonate_user = fields.Boolean(description=impersonate_user_description) + extra = fields.String(description=extra_description, validate=extra_validator) + encrypted_extra = fields.String( + description=encrypted_extra_description, validate=encrypted_extra_validator + ) + server_cert = fields.String( + description=server_cert_description, validate=server_cert_validator + ) + sqlalchemy_uri = fields.String( + description=sqlalchemy_uri_description, + required=True, + validate=[Length(1, 1024), sqlalchemy_uri_validator], + ) + + class TableMetadataOptionsResponseSchema(Schema): deferrable = fields.Bool() initially = fields.Bool() diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index d1fabc9f3e6d4..9643f094f6361 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -70,8 +70,8 @@ if TYPE_CHECKING: # pylint: disable=unused-import - from werkzeug.datastructures import TypeConversionDict from flask_appbuilder.security.sqla.models import User + from werkzeug.datastructures import TypeConversionDict # Globals config = app.config diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 6ee2016cf2c8a..458fa83e3876a 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -91,24 +91,25 @@ class BaseSupersetModelRestApi(ModelRestApi): csrf_exempt = False method_permission_name = { - "get_list": "list", - "get": "show", + "bulk_delete": "delete", + "data": "list", + "delete": "delete", + "distinct": "list", "export": "mulexport", + "get": "show", + "get_list": "list", + "info": "list", "post": "add", "put": "edit", - "delete": "delete", - "bulk_delete": "delete", - "info": "list", - "related": "list", - "distinct": "list", - "thumbnail": "list", "refresh": "edit", - "data": "list", - "viz_types": "list", + "related": "list", "related_objects": "list", - "table_metadata": "list", - "select_star": "list", "schemas": "list", + "select_star": "list", + "table_metadata": "list", + "test_connection": "post", + "thumbnail": "list", + "viz_types": "list", } order_rel_fields: Dict[str, Tuple[str, str]] = {} diff --git a/superset/views/core.py b/superset/views/core.py index a96ce15d2a2e4..1ee501e57a1da 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1162,7 +1162,7 @@ def testconn( # pylint: disable=too-many-return-statements,no-self-use logger.warning("Stopped an unsafe database connection") return json_error_response(_(str(ex)), 400) except Exception as ex: # pylint: disable=broad-except - logger.error("Unexpected error %s", type(ex).__name__) + logger.warning("Unexpected error %s", type(ex).__name__) return json_error_response( _("Unexpected error occurred, please check your logs for details"), 400 ) diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index 6d82202df690d..e07b7ac8a072f 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -21,13 +21,13 @@ import prison from sqlalchemy.sql import func -import tests.test_app from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database from superset.utils.core import get_example_database, get_main_database from tests.base_tests import SupersetTestCase from tests.fixtures.certificates import ssl_certificate +from tests.test_app import app class TestDatabaseApi(SupersetTestCase): @@ -652,6 +652,97 @@ def test_database_schemas_invalid_query(self): ) self.assertEqual(rv.status_code, 400) + def test_test_connection(self): + """ + Database API: Test test connection + """ + # need to temporarily allow sqlite dbs, teardown will undo this + app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False + self.login("admin") + example_db = get_example_database() + # validate that the endpoint works with the password-masked sqlalchemy uri + data = { + "sqlalchemy_uri": example_db.safe_sqlalchemy_uri(), + "database_name": "examples", + "impersonate_user": False, + } + url = f"api/v1/database/test_connection" + rv = self.post_assert_metric(url, data, "test_connection") + self.assertEqual(rv.status_code, 200) + self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8") + + # validate that the endpoint works with the decrypted sqlalchemy uri + data = { + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "database_name": "examples", + "impersonate_user": False, + } + rv = self.post_assert_metric(url, data, "test_connection") + self.assertEqual(rv.status_code, 200) + self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8") + + def test_test_connection_failed(self): + """ + Database API: Test test connection failed + """ + self.login("admin") + + data = { + "sqlalchemy_uri": "broken://url", + "database_name": "examples", + "impersonate_user": False, + } + url = f"api/v1/database/test_connection" + rv = self.post_assert_metric(url, data, "test_connection") + self.assertEqual(rv.status_code, 400) + self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8") + response = json.loads(rv.data.decode("utf-8")) + expected_response = { + "driver_name": "broken", + "message": "Could not load database driver: broken", + } + self.assertEqual(response, expected_response) + + data = { + "sqlalchemy_uri": "mssql+pymssql://url", + "database_name": "examples", + "impersonate_user": False, + } + rv = self.post_assert_metric(url, data, "test_connection") + self.assertEqual(rv.status_code, 400) + self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8") + response = json.loads(rv.data.decode("utf-8")) + expected_response = { + "driver_name": "mssql+pymssql", + "message": "Could not load database driver: mssql+pymssql", + } + self.assertEqual(response, expected_response) + + def test_test_connection_unsafe_uri(self): + """ + Database API: Test test connection with unsafe uri + """ + self.login("admin") + + app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True + data = { + "sqlalchemy_uri": "sqlite:///home/superset/unsafe.db", + "database_name": "unsafe", + "impersonate_user": False, + } + url = f"api/v1/database/test_connection" + rv = self.post_assert_metric(url, data, "test_connection") + self.assertEqual(rv.status_code, 400) + response = json.loads(rv.data.decode("utf-8")) + expected_response = { + "message": { + "sqlalchemy_uri": [ + "SQLite database cannot be used as a data source for security reasons." + ] + } + } + self.assertEqual(response, expected_response) + def test_get_database_related_objects(self): """ Database API: Test get chart and dashboard count related to a database